diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index e184a9db15..6d088ef9b7 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -31,7 +31,10 @@ from captum._utils.progress import progress, SimpleProgress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution -from captum.attr._utils.common import _format_input_baseline +from captum.attr._utils.common import ( + _format_input_baseline, + get_total_features_from_mask, +) from captum.log import log_usage from torch import dtype, Tensor from torch.futures import collect_all, Future @@ -894,7 +897,9 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - math.ceil(int(sum(feature_counts)) / perturbations_per_eval) + math.ceil( + get_total_features_from_mask(feature_mask) / perturbations_per_eval + ) if enable_cross_tensor_attribution else sum( math.ceil(count / perturbations_per_eval) for count in feature_counts diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 0333637744..74eca9327a 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -390,3 +390,16 @@ def _construct_default_feature_mask( total_features = current_num_features feature_mask = tuple(feature_mask) return feature_mask, total_features + + +def get_total_features_from_mask( + feature_mask: Tuple[Tensor, ...], +) -> int: + """ + Return the numbers of input features based on the total unique + feature IDs/indices in the feature mask. + """ + seen_idxs = set() + for mask in feature_mask: + seen_idxs |= set(torch.unique(mask).tolist()) + return len(seen_idxs) diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index 0c4d5d232c..12e0a24611 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -14,6 +14,7 @@ parse_version, safe_div, ) +from captum.attr._utils.common import get_total_features_from_mask from captum.testing.helpers.basic import ( assertTensorAlmostEqual, assertTensorTuplesAlmostEqual, @@ -174,6 +175,16 @@ def test_get_max_feature_index(self) -> None: assert _get_max_feature_index(mask) == 100 + def test_mask_unique_elem(self) -> None: + res = get_total_features_from_mask((torch.tensor([0, 0, 0]),)) + self.assertEqual(res, 1) + res = get_total_features_from_mask((torch.tensor([0, 0, 4]),)) + self.assertEqual(res, 2) + res = get_total_features_from_mask( + (torch.tensor([0, 0, 4]), torch.tensor([0, 4, 5])) + ) + self.assertEqual(res, 3) + class TestParseVersion(BaseTest): def test_parse_version_dev(self) -> None: