Skip to content

Commit 6ebfecb

Browse files
committed
Clarify types
1 parent ae4cb38 commit 6ebfecb

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

src/reward_preprocessing/interpret.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from PIL import Image
55
from imitation.scripts.common import common as common_config
6+
from imitation.util.logger import HierarchicalLogger
67
from lucent.modelzoo.util import get_model_layers
78
from lucent.optvis import transform
89
import matplotlib
@@ -221,15 +222,15 @@ def interpret(
221222

222223

223224
def plot_img(
224-
columns,
225-
custom_logger,
226-
feature_i,
227-
fig: Optional,
228-
img,
229-
pyplot,
230-
rows,
231-
vis_scale,
232-
wandb_logging,
225+
columns: int,
226+
custom_logger: HierarchicalLogger,
227+
feature_i: int,
228+
fig: Optional[matplotlib.figure.Figure],
229+
img: np.ndarray,
230+
pyplot: bool,
231+
rows: int,
232+
vis_scale: int,
233+
wandb_logging: bool,
233234
):
234235
"""Plot the passed image to pyplot and wandb as appropriate."""
235236
_wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
@@ -239,7 +240,11 @@ def plot_img(
239240

240241

241242
def _wandb_log(
242-
custom_logger, feature_i: int, img: np.ndarray, vis_scale: int, wandb_logging: bool
243+
custom_logger: HierarchicalLogger,
244+
feature_i: int,
245+
img: np.ndarray,
246+
vis_scale: int,
247+
wandb_logging: bool,
243248
):
244249
"""Plot to wandb if wandb logging is enabled."""
245250
if wandb_logging:

0 commit comments

Comments
 (0)