Skip to content

Commit ae4cb38

Browse files
committed
Fix undeclared var when disabling pyplot
1 parent a6491b3 commit ae4cb38

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/reward_preprocessing/interpret.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def interpret(
157157
rows, columns = 1, num_features
158158
if pyplot:
159159
fig = plt.figure(figsize=(columns * 2, rows * 2)) # width, height in inches
160+
else:
161+
fig = None
160162

161163
# Visualize
162164
if vis_type == "traditional":
@@ -219,11 +221,19 @@ def interpret(
219221

220222

221223
def plot_img(
222-
columns, custom_logger, feature_i, fig, img, pyplot, rows, vis_scale, wandb_logging
224+
columns,
225+
custom_logger,
226+
feature_i,
227+
fig: Optional,
228+
img,
229+
pyplot,
230+
rows,
231+
vis_scale,
232+
wandb_logging,
223233
):
224234
"""Plot the passed image to pyplot and wandb as appropriate."""
225235
_wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
226-
if pyplot:
236+
if fig is not None and pyplot:
227237
fig.add_subplot(rows, columns, feature_i + 1)
228238
plt.imshow(img)
229239

0 commit comments

Comments
 (0)