Skip to content

Commit 3b54248

Browse files
committed
Merge branch 'try-feature-vis' into extend-interpret
2 parents 554d9ef + bf6ed65 commit 3b54248

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def log_img_wandb(
264264
)
265265
else:
266266
raise ValueError(
267-
f"img must be np.ndarray or PIL.Image.Image, type(img)={type(img)}"
267+
f"img must be np.ndarray or PIL.Image.Image, {type(img)=}"
268268
)
269269
wb_img = wandb.Image(pil_img, caption=caption)
270270
logger.record(wandb_key, wb_img)

src/reward_preprocessing/interpret.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ def interpret(
110110
Limit how many of the transitions from `rollout_path` are used for
111111
dimensionality reduction. The RL Vision paper uses "a few thousand"
112112
sampled infrequently from rollouts.
113-
pyplot: Whether to plot images as pyplot figures.
113+
pyplot:
114+
Whether to plot visualizations as pyplot figures. Set to False when running
115+
interpret in a non-GUI environment, such as the cluster. In that case, use
116+
wandb logging or save images to disk.
114117
vis_scale: Scale the plotted images by this factor.
115118
vis_type:
116119
Type of visualization to use. Either "traditional" for gradient-based
@@ -133,6 +136,9 @@ def interpret(
133136
this must also not be None.
134137
img_save_path:
135138
Directory to save images in. Must end in a /. If None, do not save images.
139+
reg:
140+
Regularization settings. See reward_preprocessing.scripts.config.interpret
141+
for defaults.
136142
"""
137143
if limit_num_obs <= 0:
138144
raise ValueError(

src/reward_preprocessing/scripts/config/interpret.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ def defaults():
1313
# Limit the number of observations to use for dim reduction.
1414
# The RL Vision paper uses "a few thousand" observations.
1515
limit_num_obs = 2048
16-
pyplot = False # Plot images as pyplot figures
16+
# Whether to plot visualizations as pyplot figures. Set to False when running
17+
# interpret in a non-GUI environment, such as the cluster. In that case, use
18+
# wandb logging or save images to disk.
19+
pyplot = False
1720
vis_scale = 4 # Scale the visualization img by this factor
1821
vis_type = "traditional" # "traditional" (gradient-based) or "dataset"
1922
# Name of the layer to visualize. To figure this out run interpret and the
@@ -31,7 +34,7 @@ def defaults():
3134
# What regularization to use for generated images.
3235
reg = {
3336
"no_gan": {
34-
"jitter": 8,
37+
"jitter": 8, # Jitter for generated images.
3538
}
3639
}
3740

0 commit comments

Comments
 (0)