Skip to content

Commit b1e2fc8

Browse files
igorsugak authored and facebook-github-bot committed
replace uses of np.ndarray with npt.NDArray (#1389)
Summary: X-link: pytorch/opacus#681 X-link: pytorch/botorch#2586 X-link: pytorch/audio#3846 This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors: ```counterexample Generic type `ndarray` expects 2 type parameters. ``` `numpy.typing.NDArray` is an alias that provides default template parameters. Differential Revision: D64619891
1 parent ad40160 commit b1e2fc8

File tree

3 files changed

+27
-23
lines changed

3 files changed

+27
-23
lines changed

captum/attr/_utils/visualization.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import matplotlib
99

1010
import numpy as np
11+
import numpy.typing as npt
1112
from matplotlib import cm, colors, pyplot as plt
1213
from matplotlib.axes import Axes
1314
from matplotlib.collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
4748
all = 4
4849

4950

50-
def _prepare_image(attr_visual: ndarray) -> ndarray:
51+
def _prepare_image(attr_visual: npt.NDArray) -> npt.NDArray:
5152
return np.clip(attr_visual.astype(int), 0, 255)
5253

5354

54-
def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
55+
def _normalize_scale(attr: npt.NDArray, scale_factor: float) -> npt.NDArray:
5556
assert scale_factor != 0, "Cannot normalize by scale factor = 0"
5657
if abs(scale_factor) < 1e-5:
5758
warnings.warn(
@@ -64,7 +65,9 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
6465
return np.clip(attr_norm, -1, 1)
6566

6667

67-
def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float:
68+
def _cumulative_sum_threshold(
69+
values: npt.NDArray, percentile: Union[int, float]
70+
) -> float:
6871
# given values should be non-negative
6972
assert percentile >= 0 and percentile <= 100, (
7073
"Percentile for thresholding must be " "between 0 and 100 inclusive."
@@ -76,11 +79,11 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) ->
7679

7780

7881
def _normalize_attr(
79-
attr: ndarray,
82+
attr: npt.NDArray,
8083
sign: str,
8184
outlier_perc: Union[int, float] = 2,
8285
reduction_axis: Optional[int] = None,
83-
) -> ndarray:
86+
) -> npt.NDArray:
8487
attr_combined = attr
8588
if reduction_axis is not None:
8689
attr_combined = np.sum(attr, axis=reduction_axis)
@@ -130,7 +133,7 @@ def _initialize_cmap_and_vmin_vmax(
130133

131134
def _visualize_original_image(
132135
plt_axis: Axes,
133-
original_image: Optional[ndarray],
136+
original_image: Optional[npt.NDArray],
134137
**kwargs: Any,
135138
) -> None:
136139
assert (
@@ -143,7 +146,7 @@ def _visualize_original_image(
143146

144147
def _visualize_heat_map(
145148
plt_axis: Axes,
146-
norm_attr: ndarray,
149+
norm_attr: npt.NDArray,
147150
cmap: Union[str, Colormap],
148151
vmin: float,
149152
vmax: float,
@@ -155,8 +158,8 @@ def _visualize_heat_map(
155158

156159
def _visualize_blended_heat_map(
157160
plt_axis: Axes,
158-
original_image: ndarray,
159-
norm_attr: ndarray,
161+
original_image: npt.NDArray,
162+
norm_attr: npt.NDArray,
160163
cmap: Union[str, Colormap],
161164
vmin: float,
162165
vmax: float,
@@ -176,8 +179,8 @@ def _visualize_blended_heat_map(
176179
def _visualize_masked_image(
177180
plt_axis: Axes,
178181
sign: str,
179-
original_image: ndarray,
180-
norm_attr: ndarray,
182+
original_image: npt.NDArray,
183+
norm_attr: npt.NDArray,
181184
**kwargs: Any,
182185
) -> None:
183186
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -190,8 +193,8 @@ def _visualize_masked_image(
190193
def _visualize_alpha_scaling(
191194
plt_axis: Axes,
192195
sign: str,
193-
original_image: ndarray,
194-
norm_attr: ndarray,
196+
original_image: npt.NDArray,
197+
norm_attr: npt.NDArray,
195198
**kwargs: Any,
196199
) -> None:
197200
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -210,8 +213,8 @@ def _visualize_alpha_scaling(
210213

211214

212215
def visualize_image_attr(
213-
attr: ndarray,
214-
original_image: Optional[ndarray] = None,
216+
attr: npt.NDArray,
217+
original_image: Optional[npt.NDArray] = None,
215218
method: str = "heat_map",
216219
sign: str = "absolute_value",
217220
plt_fig_axis: Optional[Tuple[Figure, Axes]] = None,
@@ -417,8 +420,8 @@ def visualize_image_attr(
417420

418421

419422
def visualize_image_attr_multiple(
420-
attr: ndarray,
421-
original_image: Union[None, ndarray],
423+
attr: npt.NDArray,
424+
original_image: Union[None, npt.NDArray],
422425
methods: List[str],
423426
signs: List[str],
424427
titles: Optional[List[str]] = None,
@@ -526,9 +529,9 @@ def visualize_image_attr_multiple(
526529

527530

528531
def visualize_timeseries_attr(
529-
attr: ndarray,
530-
data: ndarray,
531-
x_values: Optional[ndarray] = None,
532+
attr: npt.NDArray,
533+
data: npt.NDArray,
534+
x_values: Optional[npt.NDArray] = None,
532535
method: str = "overlay_individual",
533536
sign: str = "absolute_value",
534537
channel_labels: Optional[List[str]] = None,

tests/attr/test_gradient_shap.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from typing import cast, Tuple
66

77
import numpy as np
8+
import numpy.typing as npt
89
import torch
910
from captum._utils.typing import Tensor
1011
from captum.attr._core.gradient_shap import GradientShap
1112
from captum.attr._core.integrated_gradients import IntegratedGradients
12-
from numpy import ndarray
1313
from tests.attr.helpers.attribution_delta_util import (
1414
assert_attribution_delta,
1515
assert_delta,
@@ -132,7 +132,7 @@ def generate_baselines_with_inputs(inputs: Tensor) -> Tensor:
132132
inp_shape = cast(Tuple[int, ...], inputs.shape)
133133
return torch.arange(0.0, inp_shape[1] * 2.0).reshape(2, inp_shape[1])
134134

135-
def generate_baselines_returns_array() -> ndarray:
135+
def generate_baselines_returns_array() -> npt.NDArray:
136136
return np.arange(0.0, num_in * 4.0).reshape(4, num_in)
137137

138138
# 10-class classification model

tests/utils/models/linear_models/_test_linear_classifier.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import captum._utils.models.linear_model.model as pytorch_model_module
77
import numpy as np
8+
import numpy.typing as npt
89
import sklearn.datasets as datasets
910
import torch
1011
from tests.helpers.evaluate_linear_model import evaluate
@@ -107,7 +108,7 @@ def compare_to_sk_learn(
107108
o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1)
108109

109110
rel_diff = cast(
110-
np.ndarray,
111+
npt.NDArray,
111112
# pyre-fixme[6]: For 1st argument expected `int` but got `Union[int, Tensor]`.
112113
(sum(o_sklearn.values()) - sum(o_pytorch.values())),
113114
) / abs(sum(o_sklearn.values()))

0 commit comments

Comments (0)