@@ -408,88 +408,175 @@ def infidelity(
408408 >>> # Computes infidelity score for saliency maps
409409 >>> infid = infidelity(net, perturb_fn, input, attribution)
410410 """
411+ # perform argument formattings
412+ inputs = _format_tensor_into_tuples (inputs ) # type: ignore
413+ if baselines is not None :
414+ baselines = _format_baseline (baselines , cast (Tuple [Tensor , ...], inputs ))
415+ additional_forward_args = _format_additional_forward_args (additional_forward_args )
416+ attributions = _format_tensor_into_tuples (attributions ) # type: ignore
411417
412- def _generate_perturbations (
413- current_n_perturb_samples : int ,
414- ) -> Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]:
415- r"""
416- The perturbations are generated for each example
417- `current_n_perturb_samples` times.
418+ # Make sure that inputs and corresponding attributions have matching sizes.
419+ assert len (inputs ) == len (attributions ), (
420+ """The number of tensors in the inputs and
421+ attributions must match. Found number of tensors in the inputs is: {} and in the
422+ attributions: {}"""
423+ ).format (len (inputs ), len (attributions ))
424+ for inp , attr in zip (inputs , attributions ):
425+ assert inp .shape == attr .shape , (
426+ """Inputs and attributions must have
427+ matching shapes. One of the input tensor's shape is {} and the
428+ attribution tensor's shape is: {}"""
429+ # pyre-fixme[16]: Module `attr` has no attribute `shape`.
430+ ).format (inp .shape , attr .shape )
418431
419- For performance reasons we are not calling `perturb_func` on each example but
420- on a batch that contains `current_n_perturb_samples`
421- repeated instances per example.
422- """
432+ bsz = inputs [0 ].size (0 )
423433
424- # pyre-fixme[3]: Return type must be annotated.
425- def call_perturb_func ():
426- r""" """
427- baselines_pert = None
428- inputs_pert : Union [Tensor , Tuple [Tensor , ...]]
429- if len (inputs_expanded ) == 1 :
430- inputs_pert = inputs_expanded [0 ]
431- if baselines_expanded is not None :
432- # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
433- # parameter.
434- baselines_pert = cast (Tuple , baselines_expanded )[0 ]
435- else :
436- inputs_pert = inputs_expanded
437- baselines_pert = baselines_expanded
438- return (
439- perturb_func (inputs_pert , baselines_pert )
440- if baselines_pert is not None
441- else perturb_func (inputs_pert )
442- )
434+ _next_infidelity_tensors = _make_next_infidelity_tensors_func (
435+ forward_func ,
436+ bsz ,
437+ perturb_func ,
438+ inputs ,
439+ baselines ,
440+ attributions ,
441+ additional_forward_args ,
442+ target ,
443+ normalize ,
444+ )
445+
446+ with torch .no_grad ():
447+ # if not normalize, directly return aggrgated MSE ((a-b)^2,)
448+ # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
449+ agg_tensors = _divide_and_aggregate_metrics (
450+ cast (Tuple [Tensor , ...], inputs ),
451+ n_perturb_samples ,
452+ _next_infidelity_tensors ,
453+ agg_func = _sum_infidelity_tensors ,
454+ max_examples_per_batch = max_examples_per_batch ,
455+ )
456+
457+ if normalize :
458+ beta_num = agg_tensors [1 ]
459+ beta_denorm = agg_tensors [0 ]
443460
444- inputs_expanded = tuple (
445- torch .repeat_interleave (input , current_n_perturb_samples , dim = 0 )
446- for input in inputs
461+ beta = safe_div (beta_num , beta_denorm )
462+
463+ infidelity_values = (
464+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
465+ # `int`.
466+ beta ** 2 * agg_tensors [0 ]
467+ - 2 * beta * agg_tensors [1 ]
468+ + agg_tensors [2 ]
447469 )
470+ else :
471+ infidelity_values = agg_tensors [0 ]
448472
449- baselines_expanded = baselines
450- if baselines is not None :
451- baselines_expanded = tuple (
452- (
453- baseline .repeat_interleave (current_n_perturb_samples , dim = 0 )
454- if isinstance (baseline , torch .Tensor )
455- and baseline .shape [0 ] == input .shape [0 ]
456- and baseline .shape [0 ] > 1
457- else baseline
458- )
473+ infidelity_values /= n_perturb_samples
474+
475+ return infidelity_values
476+
477+
478+ def _generate_perturbations (
479+ current_n_perturb_samples : int ,
480+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
481+ perturb_func : Callable ,
482+ inputs : TensorOrTupleOfTensorsGeneric ,
483+ baselines : BaselineType ,
484+ ) -> Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]:
485+ r"""
486+ The perturbations are generated for each example
487+ `current_n_perturb_samples` times.
488+
489+ For performance reasons we are not calling `perturb_func` on each example but
490+ on a batch that contains `current_n_perturb_samples`
491+ repeated instances per example.
492+ """
493+
494+ # pyre-fixme[3]: Return type must be annotated.
495+ def call_perturb_func ():
496+ r""" """
497+ baselines_pert = None
498+ inputs_pert : Union [Tensor , Tuple [Tensor , ...]]
499+ if len (inputs_expanded ) == 1 :
500+ inputs_pert = inputs_expanded [0 ]
501+ if baselines_expanded is not None :
459502 # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
460503 # parameter.
461- for input , baseline in zip (inputs , cast (Tuple , baselines ))
504+ baselines_pert = cast (Tuple , baselines_expanded )[0 ]
505+ else :
506+ inputs_pert = inputs_expanded
507+ baselines_pert = baselines_expanded
508+ return (
509+ perturb_func (inputs_pert , baselines_pert )
510+ if baselines_pert is not None
511+ else perturb_func (inputs_pert )
512+ )
513+
514+ inputs_expanded = tuple (
515+ torch .repeat_interleave (input , current_n_perturb_samples , dim = 0 )
516+ for input in inputs
517+ )
518+
519+ baselines_expanded = baselines
520+ if baselines is not None :
521+ baselines_expanded = tuple (
522+ (
523+ baseline .repeat_interleave (current_n_perturb_samples , dim = 0 )
524+ if isinstance (baseline , torch .Tensor )
525+ and baseline .shape [0 ] == input .shape [0 ]
526+ and baseline .shape [0 ] > 1
527+ else baseline
462528 )
529+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
530+ # parameter.
531+ for input , baseline in zip (inputs , cast (Tuple , baselines ))
532+ )
463533
464- return call_perturb_func ()
465-
466- def _validate_inputs_and_perturbations (
467- inputs : Tuple [Tensor , ...],
468- inputs_perturbed : Tuple [Tensor , ...],
469- perturbations : Tuple [Tensor , ...],
470- ) -> None :
471- # asserts the sizes of the perturbations and inputs
472- assert len (perturbations ) == len (inputs ), (
473- """The number of perturbed
474- inputs and corresponding perturbations must have the same number of
475- elements. Found number of inputs is: {} and perturbations:
476- {}"""
477- ).format (len (perturbations ), len (inputs ))
478-
479- # asserts the shapes of the perturbations and perturbed inputs
480- for perturb , input_perturbed in zip (perturbations , inputs_perturbed ):
481- assert perturb [0 ].shape == input_perturbed [0 ].shape , (
482- """Perturbed input
483- and corresponding perturbation must have the same shape and
484- dimensionality. Found perturbation shape is: {} and the input shape
485- is: {}"""
486- ).format (perturb [0 ].shape , input_perturbed [0 ].shape )
534+ return call_perturb_func ()
535+
536+
537+ def _validate_inputs_and_perturbations (
538+ inputs : Tuple [Tensor , ...],
539+ inputs_perturbed : Tuple [Tensor , ...],
540+ perturbations : Tuple [Tensor , ...],
541+ ) -> None :
542+ # asserts the sizes of the perturbations and inputs
543+ assert len (perturbations ) == len (inputs ), (
544+ """The number of perturbed
545+ inputs and corresponding perturbations must have the same number of
546+ elements. Found number of inputs is: {} and perturbations:
547+ {}"""
548+ ).format (len (perturbations ), len (inputs ))
549+
550+ # asserts the shapes of the perturbations and perturbed inputs
551+ for perturb , input_perturbed in zip (perturbations , inputs_perturbed ):
552+ assert perturb [0 ].shape == input_perturbed [0 ].shape , (
553+ """Perturbed input
554+ and corresponding perturbation must have the same shape and
555+ dimensionality. Found perturbation shape is: {} and the input shape
556+ is: {}"""
557+ ).format (perturb [0 ].shape , input_perturbed [0 ].shape )
558+
559+
560+ def _make_next_infidelity_tensors_func (
561+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
562+ forward_func : Callable ,
563+ bsz : int ,
564+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
565+ perturb_func : Callable ,
566+ inputs : TensorOrTupleOfTensorsGeneric ,
567+ baselines : BaselineType ,
568+ attributions : TensorOrTupleOfTensorsGeneric ,
569+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
570+ additional_forward_args : Any = None ,
571+ target : TargetType = None ,
572+ normalize : bool = False ,
573+ ) -> Callable [[int ], Union [Tuple [Tensor ], Tuple [Tensor , Tensor , Tensor ]]]:
487574
488575 def _next_infidelity_tensors (
489576 current_n_perturb_samples : int ,
490577 ) -> Union [Tuple [Tensor ], Tuple [Tensor , Tensor , Tensor ]]:
491578 perturbations , inputs_perturbed = _generate_perturbations (
492- current_n_perturb_samples
579+ current_n_perturb_samples , perturb_func , inputs , baselines
493580 )
494581
495582 perturbations = _format_tensor_into_tuples (perturbations )
@@ -564,60 +651,10 @@ def _next_infidelity_tensors(
564651 # returns (a-b)^2 if no need to normalize
565652 return ((attr_times_perturb_sums - perturbed_fwd_diffs ).pow (2 ).sum (- 1 ),)
566653
567- # pyre-fixme[3]: Return type must be annotated.
568- # pyre-fixme[2]: Parameter must be annotated.
569- def _sum_infidelity_tensors (agg_tensors , tensors ):
570- return tuple (agg_t + t for agg_t , t in zip (agg_tensors , tensors ))
654+ return _next_infidelity_tensors
571655
572- # perform argument formattings
573- inputs = _format_tensor_into_tuples (inputs ) # type: ignore
574- if baselines is not None :
575- baselines = _format_baseline (baselines , cast (Tuple [Tensor , ...], inputs ))
576- additional_forward_args = _format_additional_forward_args (additional_forward_args )
577- attributions = _format_tensor_into_tuples (attributions ) # type: ignore
578656
579- # Make sure that inputs and corresponding attributions have matching sizes.
580- assert len (inputs ) == len (attributions ), (
581- """The number of tensors in the inputs and
582- attributions must match. Found number of tensors in the inputs is: {} and in the
583- attributions: {}"""
584- ).format (len (inputs ), len (attributions ))
585- for inp , attr in zip (inputs , attributions ):
586- assert inp .shape == attr .shape , (
587- """Inputs and attributions must have
588- matching shapes. One of the input tensor's shape is {} and the
589- attribution tensor's shape is: {}"""
590- # pyre-fixme[16]: Module `attr` has no attribute `shape`.
591- ).format (inp .shape , attr .shape )
592-
593- bsz = inputs [0 ].size (0 )
594- with torch .no_grad ():
595- # if not normalize, directly return aggrgated MSE ((a-b)^2,)
596- # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
597- agg_tensors = _divide_and_aggregate_metrics (
598- cast (Tuple [Tensor , ...], inputs ),
599- n_perturb_samples ,
600- _next_infidelity_tensors ,
601- agg_func = _sum_infidelity_tensors ,
602- max_examples_per_batch = max_examples_per_batch ,
603- )
604-
605- if normalize :
606- beta_num = agg_tensors [1 ]
607- beta_denorm = agg_tensors [0 ]
608-
609- beta = safe_div (beta_num , beta_denorm )
610-
611- infidelity_values = (
612- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
613- # `int`.
614- beta ** 2 * agg_tensors [0 ]
615- - 2 * beta * agg_tensors [1 ]
616- + agg_tensors [2 ]
617- )
618- else :
619- infidelity_values = agg_tensors [0 ]
620-
621- infidelity_values /= n_perturb_samples
622-
623- return infidelity_values
657+ # pyre-fixme[3]: Return type must be annotated.
658+ # pyre-fixme[2]: Parameter must be annotated.
659+ def _sum_infidelity_tensors (agg_tensors , tensors ):
660+ return tuple (agg_t + t for agg_t , t in zip (agg_tensors , tensors ))
0 commit comments