
Commit 8ee3699

igorsugak authored and facebook-github-bot committed
Upgrade fbcode/pytorch to Python Scientific Stack 2 (#3845)
Summary: Pull Request resolved: pytorch/audio#3845

Differential Revision: D64008689
1 parent 8dc48a5 commit 8ee3699

File tree

3 files changed (+28, -22 lines)

captum/attr/_utils/visualization.py

Lines changed: 24 additions & 20 deletions
@@ -8,6 +8,7 @@
 import matplotlib
 
 import numpy as np
+import numpy.typing as npt
 from matplotlib import cm, colors, pyplot as plt
 from matplotlib.axes import Axes
 from matplotlib.collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
     all = 4
 
 
-def _prepare_image(attr_visual: ndarray) -> ndarray:
+def _prepare_image(attr_visual: npt.NDArray) -> npt.NDArray:
     return np.clip(attr_visual.astype(int), 0, 255)
 
 
-def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
+def _normalize_scale(attr: npt.NDArray, scale_factor: float) -> npt.NDArray:
     assert scale_factor != 0, "Cannot normalize by scale factor = 0"
     if abs(scale_factor) < 1e-5:
         warnings.warn(
@@ -64,23 +65,26 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
     return np.clip(attr_norm, -1, 1)
 
 
-def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float:
+def _cumulative_sum_threshold(
+    values: npt.NDArray, percentile: Union[int, float]
+) -> float:
     # given values should be non-negative
     assert percentile >= 0 and percentile <= 100, (
         "Percentile for thresholding must be " "between 0 and 100 inclusive."
     )
     sorted_vals = np.sort(values.flatten())
     cum_sums = np.cumsum(sorted_vals)
     threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
+    # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
     return sorted_vals[threshold_id]
 
 
 def _normalize_attr(
-    attr: ndarray,
+    attr: npt.NDArray,
     sign: str,
     outlier_perc: Union[int, float] = 2,
     reduction_axis: Optional[int] = None,
-) -> ndarray:
+) -> npt.NDArray:
     attr_combined = attr
     if reduction_axis is not None:
         attr_combined = np.sum(attr, axis=reduction_axis)
@@ -130,7 +134,7 @@ def _initialize_cmap_and_vmin_vmax(
 
 def _visualize_original_image(
     plt_axis: Axes,
-    original_image: Optional[ndarray],
+    original_image: Optional[npt.NDArray],
     **kwargs: Any,
 ) -> None:
     assert (
@@ -143,7 +147,7 @@ def _visualize_original_image(
 
 def _visualize_heat_map(
     plt_axis: Axes,
-    norm_attr: ndarray,
+    norm_attr: npt.NDArray,
     cmap: Union[str, Colormap],
     vmin: float,
     vmax: float,
@@ -155,8 +159,8 @@ def _visualize_heat_map(
 
 def _visualize_blended_heat_map(
     plt_axis: Axes,
-    original_image: ndarray,
-    norm_attr: ndarray,
+    original_image: npt.NDArray,
+    norm_attr: npt.NDArray,
     cmap: Union[str, Colormap],
     vmin: float,
     vmax: float,
@@ -176,8 +180,8 @@ def _visualize_blended_heat_map(
 def _visualize_masked_image(
     plt_axis: Axes,
     sign: str,
-    original_image: ndarray,
-    norm_attr: ndarray,
+    original_image: npt.NDArray,
+    norm_attr: npt.NDArray,
     **kwargs: Any,
 ) -> None:
     assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -190,8 +194,8 @@ def _visualize_masked_image(
 def _visualize_alpha_scaling(
     plt_axis: Axes,
     sign: str,
-    original_image: ndarray,
-    norm_attr: ndarray,
+    original_image: npt.NDArray,
+    norm_attr: npt.NDArray,
     **kwargs: Any,
 ) -> None:
     assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -210,8 +214,8 @@ def _visualize_alpha_scaling(
 
 
 def visualize_image_attr(
-    attr: ndarray,
-    original_image: Optional[ndarray] = None,
+    attr: npt.NDArray,
+    original_image: Optional[npt.NDArray] = None,
     method: str = "heat_map",
     sign: str = "absolute_value",
     plt_fig_axis: Optional[Tuple[Figure, Axes]] = None,
@@ -417,8 +421,8 @@ def visualize_image_attr(
 
 
 def visualize_image_attr_multiple(
-    attr: ndarray,
-    original_image: Union[None, ndarray],
+    attr: npt.NDArray,
+    original_image: Union[None, npt.NDArray],
     methods: List[str],
     signs: List[str],
     titles: Optional[List[str]] = None,
@@ -526,9 +530,9 @@ def visualize_image_attr_multiple(
 
 
 def visualize_timeseries_attr(
-    attr: ndarray,
-    data: ndarray,
-    x_values: Optional[ndarray] = None,
+    attr: npt.NDArray,
+    data: npt.NDArray,
+    x_values: Optional[npt.NDArray] = None,
     method: str = "overlay_individual",
     sign: str = "absolute_value",
     channel_labels: Optional[List[str]] = None,
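
The diff above is a mechanical migration from the bare ndarray name to numpy.typing.NDArray annotations. As a rough illustration of the pattern (a minimal sketch, not part of the commit; the function names below are hypothetical and do not exist in captum), the new-style signatures look like this:

import numpy as np
import numpy.typing as npt


def scale_to_unit(attr: npt.NDArray) -> npt.NDArray:
    # Hypothetical helper: normalize an attribution array to [-1, 1],
    # mirroring the kind of signatures updated in visualization.py.
    max_abs = float(np.max(np.abs(attr)))
    return attr / max_abs if max_abs != 0 else attr


def as_float_image(attr: npt.NDArray) -> npt.NDArray[np.float64]:
    # NDArray can also be parameterized with a scalar dtype when useful.
    return attr.astype(np.float64)

npt.NDArray (available since NumPy 1.20) is a generic alias for np.ndarray that type checkers such as pyre and mypy understand; left unparameterized, it accepts an array of any dtype, which keeps this migration behavior-preserving.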

tests/attr/test_gradient_shap.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 from typing import cast, Tuple
 
 import numpy as np
+import numpy.typing as npt
 import torch
 from captum._utils.typing import Tensor
 from captum.attr._core.gradient_shap import GradientShap
@@ -132,7 +133,7 @@ def generate_baselines_with_inputs(inputs: Tensor) -> Tensor:
     inp_shape = cast(Tuple[int, ...], inputs.shape)
     return torch.arange(0.0, inp_shape[1] * 2.0).reshape(2, inp_shape[1])
 
-def generate_baselines_returns_array() -> ndarray:
+def generate_baselines_returns_array() -> npt.NDArray:
     return np.arange(0.0, num_in * 4.0).reshape(4, num_in)
 
 # 10-class classification model

tests/utils/models/linear_models/_test_linear_classifier.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 import captum._utils.models.linear_model.model as pytorch_model_module
 import numpy as np
+import numpy.typing as npt
 import sklearn.datasets as datasets
 import torch
 from tests.helpers.evaluate_linear_model import evaluate
@@ -107,7 +108,7 @@ def compare_to_sk_learn(
     o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1)
 
     rel_diff = cast(
-        np.ndarray,
+        npt.NDArray,
         # pyre-fixme[6]: For 1st argument expected `int` but got `Union[int, Tensor]`.
         (sum(o_sklearn.values()) - sum(o_pytorch.values())),
     ) / abs(sum(o_sklearn.values()))
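
This last hunk swaps the cast target from np.ndarray to npt.NDArray. typing.cast has no runtime effect, so either spelling behaves identically when the tests run; the change only affects what the type checker sees. A minimal sketch of the same pattern, with made-up data rather than the test's values:

from typing import cast

import numpy as np
import numpy.typing as npt

values = {"a": np.arange(3.0), "b": np.ones(3)}
# cast() is a no-op at runtime; it only tells the checker the summed result is an array.
total = cast(npt.NDArray, sum(values.values()))
print(total)  # [1. 2. 3.]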
