File tree Expand file tree Collapse file tree 3 files changed +13
-4
lines changed Expand file tree Collapse file tree 3 files changed +13
-4
lines changed Original file line number Diff line number Diff line change @@ -264,7 +264,7 @@ def log_img_wandb(
264
264
)
265
265
else :
266
266
raise ValueError (
267
- f"img must be np.ndarray or PIL.Image.Image, type(img)= { type (img )} "
267
+ f"img must be np.ndarray or PIL.Image.Image, { type (img )= } "
268
268
)
269
269
wb_img = wandb .Image (pil_img , caption = caption )
270
270
logger .record (wandb_key , wb_img )
Original file line number Diff line number Diff line change @@ -110,7 +110,10 @@ def interpret(
110
110
Limit how many of the transitions from `rollout_path` are used for
111
111
dimensionality reduction. The RL Vision paper uses "a few thousand"
112
112
sampled infrequently from rollouts.
113
- pyplot: Whether to plot images as pyplot figures.
113
+ pyplot:
114
+ Whether to plot visualizations as pyplot figures. Set to False when running
115
+ interpret in a non-GUI environment, such as the cluster. In that case, use
116
+ wandb logging or save images to disk.
114
117
vis_scale: Scale the plotted images by this factor.
115
118
vis_type:
116
119
Type of visualization to use. Either "traditional" for gradient-based
@@ -133,6 +136,9 @@ def interpret(
133
136
this must also not be None.
134
137
img_save_path:
135
138
Directory to save images in. Must end in a /. If None, do not save images.
139
+ reg:
140
+ Regularization settings. See reward_preprocessing.scripts.config.interpret
141
+ for defaults.
136
142
"""
137
143
if limit_num_obs <= 0 :
138
144
raise ValueError (
Original file line number Diff line number Diff line change @@ -13,7 +13,10 @@ def defaults():
13
13
# Limit the number of observations to use for dim reduction.
14
14
# The RL Vision paper uses "a few thousand" observations.
15
15
limit_num_obs = 2048
16
- pyplot = False # Plot images as pyplot figures
16
+ # Whether to plot visualizations as pyplot figures. Set to False when running
17
+ # interpret in a non-GUI environment, such as the cluster. In that case, use
18
+ # wandb logging or save images to disk.
19
+ pyplot = False
17
20
vis_scale = 4 # Scale the visualization img by this factor
18
21
vis_type = "traditional" # "traditional" (gradient-based) or "dataset"
19
22
# Name of the layer to visualize. To figure this out run interpret and the
@@ -31,7 +34,7 @@ def defaults():
31
34
# What regularization to use for generated images.
32
35
reg = {
33
36
"no_gan" : {
34
- "jitter" : 8 ,
37
+ "jitter" : 8 , # Jitter for generated images.
35
38
}
36
39
}
37
40
You can’t perform that action at this time.
0 commit comments