8
8
import matplotlib
9
9
10
10
import numpy as np
11
+ import numpy .typing as npt
11
12
from matplotlib import cm , colors , pyplot as plt
12
13
from matplotlib .axes import Axes
13
14
from matplotlib .collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
47
48
all = 4
48
49
49
50
50
- def _prepare_image (attr_visual : ndarray ) -> ndarray :
51
+ def _prepare_image (attr_visual : npt . NDArray ) -> npt . NDArray :
51
52
return np .clip (attr_visual .astype (int ), 0 , 255 )
52
53
53
54
54
- def _normalize_scale (attr : ndarray , scale_factor : float ) -> ndarray :
55
+ def _normalize_scale (attr : npt . NDArray , scale_factor : float ) -> npt . NDArray :
55
56
assert scale_factor != 0 , "Cannot normalize by scale factor = 0"
56
57
if abs (scale_factor ) < 1e-5 :
57
58
warnings .warn (
@@ -64,7 +65,9 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
64
65
return np .clip (attr_norm , - 1 , 1 )
65
66
66
67
67
- def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]) -> float :
68
+ def _cumulative_sum_threshold (
69
+ values : npt .NDArray , percentile : Union [int , float ]
70
+ ) -> float :
68
71
# given values should be non-negative
69
72
assert percentile >= 0 and percentile <= 100 , (
70
73
"Percentile for thresholding must be " "between 0 and 100 inclusive."
@@ -76,11 +79,11 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) ->
76
79
77
80
78
81
def _normalize_attr (
79
- attr : ndarray ,
82
+ attr : npt . NDArray ,
80
83
sign : str ,
81
84
outlier_perc : Union [int , float ] = 2 ,
82
85
reduction_axis : Optional [int ] = None ,
83
- ) -> ndarray :
86
+ ) -> npt . NDArray :
84
87
attr_combined = attr
85
88
if reduction_axis is not None :
86
89
attr_combined = np .sum (attr , axis = reduction_axis )
@@ -130,7 +133,7 @@ def _initialize_cmap_and_vmin_vmax(
130
133
131
134
def _visualize_original_image (
132
135
plt_axis : Axes ,
133
- original_image : Optional [ndarray ],
136
+ original_image : Optional [npt . NDArray ],
134
137
** kwargs : Any ,
135
138
) -> None :
136
139
assert (
@@ -143,7 +146,7 @@ def _visualize_original_image(
143
146
144
147
def _visualize_heat_map (
145
148
plt_axis : Axes ,
146
- norm_attr : ndarray ,
149
+ norm_attr : npt . NDArray ,
147
150
cmap : Union [str , Colormap ],
148
151
vmin : float ,
149
152
vmax : float ,
@@ -155,8 +158,8 @@ def _visualize_heat_map(
155
158
156
159
def _visualize_blended_heat_map (
157
160
plt_axis : Axes ,
158
- original_image : ndarray ,
159
- norm_attr : ndarray ,
161
+ original_image : npt . NDArray ,
162
+ norm_attr : npt . NDArray ,
160
163
cmap : Union [str , Colormap ],
161
164
vmin : float ,
162
165
vmax : float ,
@@ -176,8 +179,8 @@ def _visualize_blended_heat_map(
176
179
def _visualize_masked_image (
177
180
plt_axis : Axes ,
178
181
sign : str ,
179
- original_image : ndarray ,
180
- norm_attr : ndarray ,
182
+ original_image : npt . NDArray ,
183
+ norm_attr : npt . NDArray ,
181
184
** kwargs : Any ,
182
185
) -> None :
183
186
assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -190,8 +193,8 @@ def _visualize_masked_image(
190
193
def _visualize_alpha_scaling (
191
194
plt_axis : Axes ,
192
195
sign : str ,
193
- original_image : ndarray ,
194
- norm_attr : ndarray ,
196
+ original_image : npt . NDArray ,
197
+ norm_attr : npt . NDArray ,
195
198
** kwargs : Any ,
196
199
) -> None :
197
200
assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -210,8 +213,8 @@ def _visualize_alpha_scaling(
210
213
211
214
212
215
def visualize_image_attr (
213
- attr : ndarray ,
214
- original_image : Optional [ndarray ] = None ,
216
+ attr : npt . NDArray ,
217
+ original_image : Optional [npt . NDArray ] = None ,
215
218
method : str = "heat_map" ,
216
219
sign : str = "absolute_value" ,
217
220
plt_fig_axis : Optional [Tuple [Figure , Axes ]] = None ,
@@ -417,8 +420,8 @@ def visualize_image_attr(
417
420
418
421
419
422
def visualize_image_attr_multiple (
420
- attr : ndarray ,
421
- original_image : Union [None , ndarray ],
423
+ attr : npt . NDArray ,
424
+ original_image : Union [None , npt . NDArray ],
422
425
methods : List [str ],
423
426
signs : List [str ],
424
427
titles : Optional [List [str ]] = None ,
@@ -526,9 +529,9 @@ def visualize_image_attr_multiple(
526
529
527
530
528
531
def visualize_timeseries_attr (
529
- attr : ndarray ,
530
- data : ndarray ,
531
- x_values : Optional [ndarray ] = None ,
532
+ attr : npt . NDArray ,
533
+ data : npt . NDArray ,
534
+ x_values : Optional [npt . NDArray ] = None ,
532
535
method : str = "overlay_individual" ,
533
536
sign : str = "absolute_value" ,
534
537
channel_labels : Optional [List [str ]] = None ,
0 commit comments