Skip to content

Commit 9340d67

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Reduce complexity of 'infidelity' (#1376)
Summary: Pull Request resolved: #1376 Reduce complexity of 'infidelity' Reviewed By: jjuncho Differential Revision: D64451749 fbshipit-source-id: f016c9f63ba3071b252ec35bb5581d38d4a96d58
1 parent 28c58da commit 9340d67

File tree

1 file changed

+159
-122
lines changed

1 file changed

+159
-122
lines changed

captum/metrics/_core/infidelity.py

Lines changed: 159 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -408,88 +408,175 @@ def infidelity(
408408
>>> # Computes infidelity score for saliency maps
409409
>>> infid = infidelity(net, perturb_fn, input, attribution)
410410
"""
411+
# perform argument formattings
412+
inputs = _format_tensor_into_tuples(inputs) # type: ignore
413+
if baselines is not None:
414+
baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
415+
additional_forward_args = _format_additional_forward_args(additional_forward_args)
416+
attributions = _format_tensor_into_tuples(attributions) # type: ignore
411417

412-
def _generate_perturbations(
413-
current_n_perturb_samples: int,
414-
) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
415-
r"""
416-
The perturbations are generated for each example
417-
`current_n_perturb_samples` times.
418+
# Make sure that inputs and corresponding attributions have matching sizes.
419+
assert len(inputs) == len(attributions), (
420+
"""The number of tensors in the inputs and
421+
attributions must match. Found number of tensors in the inputs is: {} and in the
422+
attributions: {}"""
423+
).format(len(inputs), len(attributions))
424+
for inp, attr in zip(inputs, attributions):
425+
assert inp.shape == attr.shape, (
426+
"""Inputs and attributions must have
427+
matching shapes. One of the input tensor's shape is {} and the
428+
attribution tensor's shape is: {}"""
429+
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
430+
).format(inp.shape, attr.shape)
418431

419-
For performance reasons we are not calling `perturb_func` on each example but
420-
on a batch that contains `current_n_perturb_samples`
421-
repeated instances per example.
422-
"""
432+
bsz = inputs[0].size(0)
423433

424-
# pyre-fixme[3]: Return type must be annotated.
425-
def call_perturb_func():
426-
r""" """
427-
baselines_pert = None
428-
inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
429-
if len(inputs_expanded) == 1:
430-
inputs_pert = inputs_expanded[0]
431-
if baselines_expanded is not None:
432-
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
433-
# parameter.
434-
baselines_pert = cast(Tuple, baselines_expanded)[0]
435-
else:
436-
inputs_pert = inputs_expanded
437-
baselines_pert = baselines_expanded
438-
return (
439-
perturb_func(inputs_pert, baselines_pert)
440-
if baselines_pert is not None
441-
else perturb_func(inputs_pert)
442-
)
434+
_next_infidelity_tensors = _make_next_infidelity_tensors_func(
435+
forward_func,
436+
bsz,
437+
perturb_func,
438+
inputs,
439+
baselines,
440+
attributions,
441+
additional_forward_args,
442+
target,
443+
normalize,
444+
)
445+
446+
with torch.no_grad():
447+
# if not normalize, directly return aggrgated MSE ((a-b)^2,)
448+
# else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
449+
agg_tensors = _divide_and_aggregate_metrics(
450+
cast(Tuple[Tensor, ...], inputs),
451+
n_perturb_samples,
452+
_next_infidelity_tensors,
453+
agg_func=_sum_infidelity_tensors,
454+
max_examples_per_batch=max_examples_per_batch,
455+
)
456+
457+
if normalize:
458+
beta_num = agg_tensors[1]
459+
beta_denorm = agg_tensors[0]
443460

444-
inputs_expanded = tuple(
445-
torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
446-
for input in inputs
461+
beta = safe_div(beta_num, beta_denorm)
462+
463+
infidelity_values = (
464+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
465+
# `int`.
466+
beta**2 * agg_tensors[0]
467+
- 2 * beta * agg_tensors[1]
468+
+ agg_tensors[2]
447469
)
470+
else:
471+
infidelity_values = agg_tensors[0]
448472

449-
baselines_expanded = baselines
450-
if baselines is not None:
451-
baselines_expanded = tuple(
452-
(
453-
baseline.repeat_interleave(current_n_perturb_samples, dim=0)
454-
if isinstance(baseline, torch.Tensor)
455-
and baseline.shape[0] == input.shape[0]
456-
and baseline.shape[0] > 1
457-
else baseline
458-
)
473+
infidelity_values /= n_perturb_samples
474+
475+
return infidelity_values
476+
477+
478+
def _generate_perturbations(
479+
current_n_perturb_samples: int,
480+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
481+
perturb_func: Callable,
482+
inputs: TensorOrTupleOfTensorsGeneric,
483+
baselines: BaselineType,
484+
) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
485+
r"""
486+
The perturbations are generated for each example
487+
`current_n_perturb_samples` times.
488+
489+
For performance reasons we are not calling `perturb_func` on each example but
490+
on a batch that contains `current_n_perturb_samples`
491+
repeated instances per example.
492+
"""
493+
494+
# pyre-fixme[3]: Return type must be annotated.
495+
def call_perturb_func():
496+
r""" """
497+
baselines_pert = None
498+
inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
499+
if len(inputs_expanded) == 1:
500+
inputs_pert = inputs_expanded[0]
501+
if baselines_expanded is not None:
459502
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
460503
# parameter.
461-
for input, baseline in zip(inputs, cast(Tuple, baselines))
504+
baselines_pert = cast(Tuple, baselines_expanded)[0]
505+
else:
506+
inputs_pert = inputs_expanded
507+
baselines_pert = baselines_expanded
508+
return (
509+
perturb_func(inputs_pert, baselines_pert)
510+
if baselines_pert is not None
511+
else perturb_func(inputs_pert)
512+
)
513+
514+
inputs_expanded = tuple(
515+
torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
516+
for input in inputs
517+
)
518+
519+
baselines_expanded = baselines
520+
if baselines is not None:
521+
baselines_expanded = tuple(
522+
(
523+
baseline.repeat_interleave(current_n_perturb_samples, dim=0)
524+
if isinstance(baseline, torch.Tensor)
525+
and baseline.shape[0] == input.shape[0]
526+
and baseline.shape[0] > 1
527+
else baseline
462528
)
529+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
530+
# parameter.
531+
for input, baseline in zip(inputs, cast(Tuple, baselines))
532+
)
463533

464-
return call_perturb_func()
465-
466-
def _validate_inputs_and_perturbations(
467-
inputs: Tuple[Tensor, ...],
468-
inputs_perturbed: Tuple[Tensor, ...],
469-
perturbations: Tuple[Tensor, ...],
470-
) -> None:
471-
# asserts the sizes of the perturbations and inputs
472-
assert len(perturbations) == len(inputs), (
473-
"""The number of perturbed
474-
inputs and corresponding perturbations must have the same number of
475-
elements. Found number of inputs is: {} and perturbations:
476-
{}"""
477-
).format(len(perturbations), len(inputs))
478-
479-
# asserts the shapes of the perturbations and perturbed inputs
480-
for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
481-
assert perturb[0].shape == input_perturbed[0].shape, (
482-
"""Perturbed input
483-
and corresponding perturbation must have the same shape and
484-
dimensionality. Found perturbation shape is: {} and the input shape
485-
is: {}"""
486-
).format(perturb[0].shape, input_perturbed[0].shape)
534+
return call_perturb_func()
535+
536+
537+
def _validate_inputs_and_perturbations(
538+
inputs: Tuple[Tensor, ...],
539+
inputs_perturbed: Tuple[Tensor, ...],
540+
perturbations: Tuple[Tensor, ...],
541+
) -> None:
542+
# asserts the sizes of the perturbations and inputs
543+
assert len(perturbations) == len(inputs), (
544+
"""The number of perturbed
545+
inputs and corresponding perturbations must have the same number of
546+
elements. Found number of inputs is: {} and perturbations:
547+
{}"""
548+
).format(len(perturbations), len(inputs))
549+
550+
# asserts the shapes of the perturbations and perturbed inputs
551+
for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
552+
assert perturb[0].shape == input_perturbed[0].shape, (
553+
"""Perturbed input
554+
and corresponding perturbation must have the same shape and
555+
dimensionality. Found perturbation shape is: {} and the input shape
556+
is: {}"""
557+
).format(perturb[0].shape, input_perturbed[0].shape)
558+
559+
560+
def _make_next_infidelity_tensors_func(
561+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
562+
forward_func: Callable,
563+
bsz: int,
564+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
565+
perturb_func: Callable,
566+
inputs: TensorOrTupleOfTensorsGeneric,
567+
baselines: BaselineType,
568+
attributions: TensorOrTupleOfTensorsGeneric,
569+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
570+
additional_forward_args: Any = None,
571+
target: TargetType = None,
572+
normalize: bool = False,
573+
) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]:
487574

488575
def _next_infidelity_tensors(
489576
current_n_perturb_samples: int,
490577
) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]:
491578
perturbations, inputs_perturbed = _generate_perturbations(
492-
current_n_perturb_samples
579+
current_n_perturb_samples, perturb_func, inputs, baselines
493580
)
494581

495582
perturbations = _format_tensor_into_tuples(perturbations)
@@ -564,60 +651,10 @@ def _next_infidelity_tensors(
564651
# returns (a-b)^2 if no need to normalize
565652
return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)
566653

567-
# pyre-fixme[3]: Return type must be annotated.
568-
# pyre-fixme[2]: Parameter must be annotated.
569-
def _sum_infidelity_tensors(agg_tensors, tensors):
570-
return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))
654+
return _next_infidelity_tensors
571655

572-
# perform argument formattings
573-
inputs = _format_tensor_into_tuples(inputs) # type: ignore
574-
if baselines is not None:
575-
baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
576-
additional_forward_args = _format_additional_forward_args(additional_forward_args)
577-
attributions = _format_tensor_into_tuples(attributions) # type: ignore
578656

579-
# Make sure that inputs and corresponding attributions have matching sizes.
580-
assert len(inputs) == len(attributions), (
581-
"""The number of tensors in the inputs and
582-
attributions must match. Found number of tensors in the inputs is: {} and in the
583-
attributions: {}"""
584-
).format(len(inputs), len(attributions))
585-
for inp, attr in zip(inputs, attributions):
586-
assert inp.shape == attr.shape, (
587-
"""Inputs and attributions must have
588-
matching shapes. One of the input tensor's shape is {} and the
589-
attribution tensor's shape is: {}"""
590-
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
591-
).format(inp.shape, attr.shape)
592-
593-
bsz = inputs[0].size(0)
594-
with torch.no_grad():
595-
# if not normalize, directly return aggrgated MSE ((a-b)^2,)
596-
# else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
597-
agg_tensors = _divide_and_aggregate_metrics(
598-
cast(Tuple[Tensor, ...], inputs),
599-
n_perturb_samples,
600-
_next_infidelity_tensors,
601-
agg_func=_sum_infidelity_tensors,
602-
max_examples_per_batch=max_examples_per_batch,
603-
)
604-
605-
if normalize:
606-
beta_num = agg_tensors[1]
607-
beta_denorm = agg_tensors[0]
608-
609-
beta = safe_div(beta_num, beta_denorm)
610-
611-
infidelity_values = (
612-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
613-
# `int`.
614-
beta**2 * agg_tensors[0]
615-
- 2 * beta * agg_tensors[1]
616-
+ agg_tensors[2]
617-
)
618-
else:
619-
infidelity_values = agg_tensors[0]
620-
621-
infidelity_values /= n_perturb_samples
622-
623-
return infidelity_values
657+
# pyre-fixme[3]: Return type must be annotated.
658+
# pyre-fixme[2]: Parameter must be annotated.
659+
def _sum_infidelity_tensors(agg_tensors, tensors):
660+
return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

0 commit comments

Comments
 (0)