8
8
import matplotlib
9
9
10
10
import numpy as np
11
+ import numpy .typing as npt
11
12
from matplotlib import cm , colors , pyplot as plt
12
13
from matplotlib .axes import Axes
13
14
from matplotlib .collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
47
48
all = 4
48
49
49
50
50
- def _prepare_image (attr_visual : ndarray ) -> ndarray :
51
+ def _prepare_image (attr_visual : npt . NDArray ) -> npt . NDArray :
51
52
return np .clip (attr_visual .astype (int ), 0 , 255 )
52
53
53
54
54
- def _normalize_scale (attr : ndarray , scale_factor : float ) -> ndarray :
55
+ def _normalize_scale (attr : npt . NDArray , scale_factor : float ) -> npt . NDArray :
55
56
assert scale_factor != 0 , "Cannot normalize by scale factor = 0"
56
57
if abs (scale_factor ) < 1e-5 :
57
58
warnings .warn (
@@ -64,23 +65,26 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
64
65
return np .clip (attr_norm , - 1 , 1 )
65
66
66
67
67
- def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]) -> float :
68
+ def _cumulative_sum_threshold (
69
+ values : npt .NDArray , percentile : Union [int , float ]
70
+ ) -> float :
68
71
# given values should be non-negative
69
72
assert percentile >= 0 and percentile <= 100 , (
70
73
"Percentile for thresholding must be " "between 0 and 100 inclusive."
71
74
)
72
75
sorted_vals = np .sort (values .flatten ())
73
76
cum_sums = np .cumsum (sorted_vals )
74
77
threshold_id = np .where (cum_sums >= cum_sums [- 1 ] * 0.01 * percentile )[0 ][0 ]
78
+ # pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
75
79
return sorted_vals [threshold_id ]
76
80
77
81
78
82
def _normalize_attr (
79
- attr : ndarray ,
83
+ attr : npt . NDArray ,
80
84
sign : str ,
81
85
outlier_perc : Union [int , float ] = 2 ,
82
86
reduction_axis : Optional [int ] = None ,
83
- ) -> ndarray :
87
+ ) -> npt . NDArray :
84
88
attr_combined = attr
85
89
if reduction_axis is not None :
86
90
attr_combined = np .sum (attr , axis = reduction_axis )
@@ -130,7 +134,7 @@ def _initialize_cmap_and_vmin_vmax(
130
134
131
135
def _visualize_original_image (
132
136
plt_axis : Axes ,
133
- original_image : Optional [ndarray ],
137
+ original_image : Optional [npt . NDArray ],
134
138
** kwargs : Any ,
135
139
) -> None :
136
140
assert (
@@ -143,7 +147,7 @@ def _visualize_original_image(
143
147
144
148
def _visualize_heat_map (
145
149
plt_axis : Axes ,
146
- norm_attr : ndarray ,
150
+ norm_attr : npt . NDArray ,
147
151
cmap : Union [str , Colormap ],
148
152
vmin : float ,
149
153
vmax : float ,
@@ -155,8 +159,8 @@ def _visualize_heat_map(
155
159
156
160
def _visualize_blended_heat_map (
157
161
plt_axis : Axes ,
158
- original_image : ndarray ,
159
- norm_attr : ndarray ,
162
+ original_image : npt . NDArray ,
163
+ norm_attr : npt . NDArray ,
160
164
cmap : Union [str , Colormap ],
161
165
vmin : float ,
162
166
vmax : float ,
@@ -176,8 +180,8 @@ def _visualize_blended_heat_map(
176
180
def _visualize_masked_image (
177
181
plt_axis : Axes ,
178
182
sign : str ,
179
- original_image : ndarray ,
180
- norm_attr : ndarray ,
183
+ original_image : npt . NDArray ,
184
+ norm_attr : npt . NDArray ,
181
185
** kwargs : Any ,
182
186
) -> None :
183
187
assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -190,8 +194,8 @@ def _visualize_masked_image(
190
194
def _visualize_alpha_scaling (
191
195
plt_axis : Axes ,
192
196
sign : str ,
193
- original_image : ndarray ,
194
- norm_attr : ndarray ,
197
+ original_image : npt . NDArray ,
198
+ norm_attr : npt . NDArray ,
195
199
** kwargs : Any ,
196
200
) -> None :
197
201
assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -210,8 +214,8 @@ def _visualize_alpha_scaling(
210
214
211
215
212
216
def visualize_image_attr (
213
- attr : ndarray ,
214
- original_image : Optional [ndarray ] = None ,
217
+ attr : npt . NDArray ,
218
+ original_image : Optional [npt . NDArray ] = None ,
215
219
method : str = "heat_map" ,
216
220
sign : str = "absolute_value" ,
217
221
plt_fig_axis : Optional [Tuple [Figure , Axes ]] = None ,
@@ -417,8 +421,8 @@ def visualize_image_attr(
417
421
418
422
419
423
def visualize_image_attr_multiple (
420
- attr : ndarray ,
421
- original_image : Union [None , ndarray ],
424
+ attr : npt . NDArray ,
425
+ original_image : Union [None , npt . NDArray ],
422
426
methods : List [str ],
423
427
signs : List [str ],
424
428
titles : Optional [List [str ]] = None ,
@@ -526,9 +530,9 @@ def visualize_image_attr_multiple(
526
530
527
531
528
532
def visualize_timeseries_attr (
529
- attr : ndarray ,
530
- data : ndarray ,
531
- x_values : Optional [ndarray ] = None ,
533
+ attr : npt . NDArray ,
534
+ data : npt . NDArray ,
535
+ x_values : Optional [npt . NDArray ] = None ,
532
536
method : str = "overlay_individual" ,
533
537
sign : str = "absolute_value" ,
534
538
channel_labels : Optional [List [str ]] = None ,
0 commit comments