diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 6d088ef9b..4b9d49850 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -729,6 +729,7 @@ def attribute_future(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
@@ -743,17 +744,18 @@ def attribute_future(
         formatted_additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        num_examples = formatted_inputs[0].shape[0]
         formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
         with torch.no_grad():
+            attr_progress = None
             if show_progress:
                 attr_progress = self._attribute_progress_setup(
                     formatted_inputs,
                     formatted_feature_mask,
+                    enable_cross_tensor_attribution,
                     **kwargs,
                     perturbations_per_eval=perturbations_per_eval,
                 )
@@ -768,7 +770,7 @@ def attribute_future(
                 formatted_additional_forward_args,
             )
 
-            if show_progress:
+            if attr_progress is not None:
                 attr_progress.update()
 
             processed_initial_eval_fut: Optional[
@@ -788,101 +790,356 @@ def attribute_future(
                 )
             )
 
-        # The will be the same amount futures as modified_eval down there,
-        # since we cannot add up the evaluation result adhoc under async mode.
-        all_modified_eval_futures: List[
-            List[Future[Tuple[List[Tensor], List[Tensor]]]]
-        ] = [[] for _ in range(len(inputs))]
-        # Iterate through each feature tensor for ablation
-        for i in range(len(formatted_inputs)):
-            # Skip any empty input tensors
-            if torch.numel(formatted_inputs[i]) == 0:
-                continue
-
-            for (
-                current_inputs,
-                current_add_args,
-                current_target,
-                current_mask,
-            ) in self._ith_input_ablation_generator(
-                i,
+        if enable_cross_tensor_attribution:
+            # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
+            #  <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
+            #  `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+            return self._attribute_with_cross_tensor_feature_masks_future(  # type: ignore # noqa: E501 line too long
+                formatted_inputs=formatted_inputs,
+                formatted_additional_forward_args=formatted_additional_forward_args,
+                target=target,
+                baselines=baselines,
+                formatted_feature_mask=formatted_feature_mask,
+                attr_progress=attr_progress,
+                processed_initial_eval_fut=processed_initial_eval_fut,
+                is_inputs_tuple=is_inputs_tuple,
+                perturbations_per_eval=perturbations_per_eval,
+            )
+        else:
+            # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
+            #  <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
+            #  `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+            return self._attribute_with_independent_feature_masks_future(  # type: ignore # noqa: E501 line too long
                 formatted_inputs,
                 formatted_additional_forward_args,
                 target,
                 baselines,
                 formatted_feature_mask,
                 perturbations_per_eval,
+                attr_progress,
+                processed_initial_eval_fut,
+                is_inputs_tuple,
                 **kwargs,
-            ):
-                # modified_eval has (n_feature_perturbed * n_outputs) elements
-                # shape:
-                #   agg mode: (*initial_eval.shape)
-                #   non-agg mode:
-                #     (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
-                    self.forward_func,
-                    current_inputs,
-                    current_target,
-                    current_add_args,
-                )
+            )
 
-                if show_progress:
-                    attr_progress.update()
+
+    def _attribute_with_independent_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        perturbations_per_eval: int,
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        **kwargs: Any,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        num_examples = formatted_inputs[0].shape[0]
+        # There will be the same number of futures as modified_eval below,
+        # since we cannot add up the evaluation results ad hoc in async mode.
+        all_modified_eval_futures: List[
+            List[Future[Tuple[List[Tensor], List[Tensor]]]]
+        ] = [[] for _ in range(len(formatted_inputs))]
+        # Iterate through each feature tensor for ablation
+        for i in range(len(formatted_inputs)):
+            # Skip any empty input tensors
+            if torch.numel(formatted_inputs[i]) == 0:
+                continue
+
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
+                i,
+                formatted_inputs,
+                formatted_additional_forward_args,
+                target,
+                baselines,
+                formatted_feature_mask,
+                perturbations_per_eval,
+                **kwargs,
+            ):
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                #   agg mode: (*initial_eval.shape)
+                #   non-agg mode:
+                #     (feature_perturbed * batch_size, *initial_eval.shape[1:])
+                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
+                    self.forward_func,
+                    current_inputs,
+                    current_target,
+                    current_add_args,
+                )
+
+                if attr_progress is not None:
+                    attr_progress.update()
+
+                if not isinstance(modified_eval, torch.Future):
+                    raise AssertionError(
+                        "when using attribute_future, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                if processed_initial_eval_fut is None:
+                    raise AssertionError(
+                        "processed_initial_eval_fut should not be None"
+                    )
 
-                # Need to collect both initial eval and modified_eval
-                eval_futs: Future[
-                    List[
-                        Future[
-                            Union[
-                                Tuple[
-                                    List[Tensor],
-                                    List[Tensor],
                                     Tensor,
-                                    Tensor,
-                                    int,
-                                    dtype,
-                                ],
+                # Need to collect both initial eval and modified_eval
+                eval_futs: Future[
+                    List[
+                        Future[
+                            Union[
+                                Tuple[
+                                    List[Tensor],
+                                    List[Tensor],
+                                    Tensor,
+                                    Tensor,
+                                    int,
+                                    dtype,
+                                ],
+                                Tensor,
                             ]
                         ]
-                ] = collect_all(
-                    [
-                        processed_initial_eval_fut,
-                        modified_eval,
-                    ]
-                )
+                    ]
+                ] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )
 
-                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
-                    eval_futs.then(
-                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
-                            eval_futs=eval_futs,
-                            current_inputs=current_inputs,
-                            current_mask=current_mask,
-                            i=i,
-                            perturbations_per_eval=perturbations_per_eval,
-                            num_examples=num_examples,
-                            formatted_inputs=formatted_inputs,
-                        )
                     )
                 )
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                            eval_futs=eval_futs,
+                            current_inputs=current_inputs,
+                            current_mask=current_mask,
+                            i=i,
+                            perturbations_per_eval=perturbations_per_eval,
+                            num_examples=num_examples,
+                            formatted_inputs=formatted_inputs,
+                        )
+                    )
+                )
 
-            all_modified_eval_futures[i].append(ablated_out_fut)
+                all_modified_eval_futures[i].append(ablated_out_fut)
 
-        if show_progress:
-            attr_progress.close()
+        if attr_progress is not None:
+            attr_progress.close()
+
+        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+
+    def _attribute_with_cross_tensor_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        perturbations_per_eval: int,
+        **kwargs: Any,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[Union[Tensor, int, float]] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
+
+        all_modified_eval_futures: List[Future[Tuple[List[Tensor], List[Tensor]]]] = []
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            should_skip = False
+            all_empty = True
+            tensor_idx_list = []
+            for feature_idx in current_feature_idxs:
+                tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
+            for tensor_idx in set(tensor_idx_list):
+                if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
+                    all_empty = False
+                if self._min_examples_per_batch_grouped is not None and (
+                    formatted_inputs[tensor_idx].shape[0]
+                    # pyre-ignore[58]: Type has been narrowed to int
+                    < self._min_examples_per_batch_grouped
+                ):
+                    should_skip = True
+                    break
+            if all_empty:
+                logger.info(
+                    f"Skipping feature group {current_feature_idxs} since all "
+                    f"input tensors are empty"
+                )
+                continue
+
+            if should_skip:
+                logger.warning(
+                    f"Skipping feature group {current_feature_idxs} since it contains "
+                    f"at least one input tensor with 0th dim less than "
+                    f"{self._min_examples_per_batch_grouped}"
+                )
+                continue
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
 
-        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )
+
+            # modified_eval has (n_feature_perturbed * n_outputs) elements
+            # shape:
+            #   agg mode: (*initial_eval.shape)
+            #   non-agg mode:
+            #     (feature_perturbed * batch_size, *initial_eval.shape[1:])
+            modified_eval = _run_forward(
+                self.forward_func,
+                current_inputs,
+                current_target,
+                current_additional_args,
+            )
+
+            if attr_progress is not None:
+                attr_progress.update()
+
+            if not isinstance(modified_eval, torch.Future):
+                raise AssertionError(
+                    "when using attribute_future, modified_eval should have "
+                    f"Future type rather than {type(modified_eval)}"
+                )
+
+            # Need to collect both initial eval and modified_eval
+            eval_futs: Future[
+                List[
+                    Future[
+                        Union[
+                            Tuple[
+                                List[Tensor],
+                                List[Tensor],
+                                Tensor,
+                                Tensor,
+                                int,
+                                dtype,
+                            ],
+                            Tensor,
+                        ]
+                    ]
+                ]
+            ] = collect_all(
+                [
+                    processed_initial_eval_fut,
+                    modified_eval,
+                ]
+            )
+
+            ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = eval_futs.then(
+                lambda eval_futs, current_inputs=current_inputs, current_mask=current_masks, i=i: self._eval_fut_to_ablated_out_fut_cross_tensor(  # type: ignore # noqa: E501 line too long
+                    eval_futs=eval_futs,
+                    current_inputs=current_inputs,
+                    current_mask=current_mask,
+                    perturbations_per_eval=perturbations_per_eval,
+                    num_examples=num_examples,
+                )
+            )
+
+            all_modified_eval_futures.append(ablated_out_fut)
+
+        if attr_progress is not None:
+            attr_progress.close()
+
+        return self._generate_async_result_cross_tensor(
+            all_modified_eval_futures,
+            is_inputs_tuple,
+        )
+
+    def _fut_tuple_to_accumulate_fut_list_cross_tensor(
+        self,
+        total_attrib: List[Tensor],
+        weights: List[Tensor],
+        fut_tuple: Future[Tuple[List[Tensor], List[Tensor]]],
+    ) -> None:
+        try:
+            # process_ablated_out_* already accumulates the total attribution.
+            # Just get the latest value
+            attribs, this_weights = fut_tuple.value()
+            total_attrib[:] = attribs
+            weights[:] = this_weights
+        except FeatureAblationFutureError as e:
+            raise FeatureAblationFutureError(
+                "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
+            ) from e
 
     # pyre-fixme[3] return type must be annotated
     def _attribute_progress_setup(
@@ -913,7 +1170,6 @@ def _attribute_progress_setup(
 
     def _eval_fut_to_ablated_out_fut(
         self,
-        # pyre-ignore Invalid type parameters [24]
         eval_futs: Future[List[Future[List[object]]]],
         current_inputs: Tuple[Tensor, ...],
         current_mask: Tensor,
@@ -975,6 +1231,94 @@ def _eval_fut_to_ablated_out_fut(
             ) from e
         return result
 
+    def _generate_async_result_cross_tensor(
+        self,
+        futs: List[Future[Tuple[List[Tensor], List[Tensor]]]],
+        is_inputs_tuple: bool,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        accumulate_fut_list: List[Future[None]] = []
+        total_attrib: List[Tensor] = []
+        weights: List[Tensor] = []
+
+        for fut_tuple in futs:
+            accumulate_fut_list.append(
+                fut_tuple.then(
+                    lambda fut_tuple: self._fut_tuple_to_accumulate_fut_list_cross_tensor(  # noqa: E501 line too long
+                        total_attrib, weights, fut_tuple
+                    )
+                )
+            )
+
+        result_fut = collect_all(accumulate_fut_list).then(
+            lambda x: self._generate_result(
+                total_attrib,
+                weights,
+                is_inputs_tuple,
+            )
+        )
+
+        return result_fut
+
+    def _eval_fut_to_ablated_out_fut_cross_tensor(
+        self,
+        eval_futs: Future[List[Future[List[object]]]],
+        current_inputs: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
+        perturbations_per_eval: int,
+        num_examples: int,
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+        try:
+            modified_eval = cast(Tensor, eval_futs.value()[1].value())
+            initial_eval_tuple = cast(
+                Tuple[
+                    List[Tensor],
+                    List[Tensor],
+                    Tensor,
+                    Tensor,
+                    int,
+                    dtype,
+                ],
+                eval_futs.value()[0].value(),
+            )
+            if len(initial_eval_tuple) != 6:
+                raise AssertionError(
+                    "eval_fut_to_ablated_out_fut_cross_tensor: "
+                    "initial_eval_tuple should have 6 elements: "
+                    "total_attrib, weights, initial_eval, "
+                    "flattened_initial_eval, n_outputs, attrib_type "
+                )
+            if not isinstance(modified_eval, Tensor):
+                raise AssertionError(
+                    "_eval_fut_to_ablated_out_fut_cross_tensor: "
+                    "modified eval should be a Tensor"
+                )
+            (
+                total_attrib,
+                weights,
+                initial_eval,
+                flattened_initial_eval,
+                n_outputs,
+                attrib_type,
+            ) = initial_eval_tuple
+            total_attrib, weights = self._process_ablated_out_full(
+                modified_eval=modified_eval,
+                inputs=current_inputs,
+                current_mask=current_mask,
+                perturbations_per_eval=perturbations_per_eval,
+                num_examples=num_examples,
+                initial_eval=initial_eval,
+                flattened_initial_eval=flattened_initial_eval,
+                n_outputs=n_outputs,
+                total_attrib=total_attrib,
+                weights=weights,
+                attrib_type=attrib_type,
+            )
+        except FeatureAblationFutureError as e:
+            raise FeatureAblationFutureError(
+                "_eval_fut_to_ablated_out_fut_cross_tensor func failed"
+            ) from e
+        return total_attrib, weights
 
     def _ith_input_ablation_generator(
         self,
         i: int,
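Usage note (not part of the diff): a minimal sketch of how the new flag might be exercised. The model, tensor shapes, and mask values below are hypothetical; the only API assumed is FeatureAblation.attribute_future with the enable_cross_tensor_attribution argument added above, which requires forward_func to return a torch.futures.Future (per the isinstance(modified_eval, torch.Future) assertions). Without the flag, the same call falls back to the previous per-tensor ablation path.

    import torch
    from torch.futures import Future
    from captum.attr import FeatureAblation

    # Hypothetical forward: returns a completed Future holding one scalar per
    # example, since attribute_future expects a Future-returning forward_func.
    def forward_func(x1: torch.Tensor, x2: torch.Tensor) -> Future:
        fut: Future = Future()
        fut.set_result(x1.sum(dim=1) + x2.sum(dim=1))
        return fut

    x1 = torch.rand(4, 3)
    x2 = torch.rand(4, 5)

    # Feature id 0 appears in both masks; with enable_cross_tensor_attribution=True
    # the matching columns of x1 and x2 are grouped into one feature and ablated
    # together in a single forward call (via feature_idx_to_tensor_idx above).
    mask1 = torch.tensor([[0, 0, 1]])
    mask2 = torch.tensor([[0, 1, 1, 2, 2]])

    ablator = FeatureAblation(forward_func)
    attr_fut = ablator.attribute_future(
        (x1, x2),
        baselines=(0.0, 0.0),
        feature_mask=(mask1, mask2),
        perturbations_per_eval=2,
        enable_cross_tensor_attribution=True,
    )
    attr_x1, attr_x2 = attr_fut.wait()  # attributions match the shapes of x1, x2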