1
1
import os .path as osp
2
- from typing import Optional
2
+ from typing import Optional , Tuple , Union
3
3
4
4
from PIL import Image
5
5
from imitation .scripts .common import common as common_config
14
14
import wandb
15
15
16
16
from reward_preprocessing .common .utils import (
17
+ RewardGeneratorCombo ,
17
18
TensorTransitionWrapper ,
18
19
rollouts_to_dataloader ,
19
20
tensor_to_transition ,
@@ -40,7 +41,9 @@ def interpret(
40
41
vis_type : str ,
41
42
layer_name : str ,
42
43
num_features : Optional [int ],
43
- gan_path : Optional [str ],
44
+ gan_path : Optional [str ] = None ,
45
+ l2_coeff : Optional [float ] = None ,
46
+ img_save_path : Optional [str ] = None ,
44
47
):
45
48
"""Run visualization for interpretability.
46
49
@@ -74,6 +77,12 @@ def interpret(
74
77
Path to the GAN model. This is used to regularize the output of the
75
78
visualization. If None simply visualize reward net without the use
76
79
of a GAN in the pipeline.
80
+ l2_coeff:
81
+ Strength with which to penalize the L2 norm of generated latent vector
82
+ "visualizations" of a GAN-reward model combination. If gan_path is not None,
83
+ this must also not be None.
84
+ img_save_path:
85
+ Directory to save images in. Must end in a /. If None, do not save images.
77
86
"""
78
87
if limit_num_obs <= 0 :
79
88
raise ValueError (
@@ -82,6 +91,15 @@ def interpret(
82
91
f"I don't think we actually ever want to use all so this is currently not "
83
92
f"implemented."
84
93
)
94
+ if vis_type not in ["dataset" , "traditional" ]:
95
+ raise ValueError (f"Unknown vis_type: { vis_type } " )
96
+ if vis_type == "dataset" and gan_path is not None :
97
+ raise ValueError ("GANs cannot be used with dataset visualization." )
98
+ if gan_path is not None and l2_coeff is None :
99
+ raise ValueError ("When GANs are used, l2_coeff must be set." )
100
+ if img_save_path is not None and img_save_path [- 1 ] != "/" :
101
+ raise ValueError ("img_save_path is not a directory, does not end in /" )
102
+
85
103
# Set up imitation-style logging.
86
104
custom_logger , log_dir = common_config .setup_logging ()
87
105
wandb_logging = "wandb" in common ["log_format_strs" ]
@@ -101,7 +119,8 @@ def interpret(
101
119
rew_net = TensorTransitionWrapper (rew_net )
102
120
else : # Use GAN
103
121
# Combine rew net with GAN.
104
- raise NotImplementedError ()
122
+ gan = th .load (gan_path , map_location = th .device (device ))
123
+ rew_net = RewardGeneratorCombo (reward_net = rew_net , generator = gan .generator )
105
124
106
125
rew_net .eval () # Eval for visualization.
107
126
@@ -129,17 +148,17 @@ def interpret(
129
148
# Ensure loaded data is FloatTensor and not DoubleTensor.
130
149
inputs = inputs .float ()
131
150
else : # When using GAN.
132
- # Inputs should be some samples of input vectors? Not sure if this is the best
133
- # way to do this, there might be better options.
134
- # The important part is that lucent expects 4D tensors as inputs, so increase
135
- # dimensionality accordingly.
136
- raise NotImplementedError ()
151
+ # Inputs are GAN samples
152
+ samples = gan . sample ( limit_num_obs )
153
+ inputs = samples [:, :, None , None ]
154
+ inputs = inputs . to ( device )
155
+ inputs = inputs . float ()
137
156
138
157
# The model to analyse should be a torch module that takes a single input, which
139
158
# should be a torch Tensor.
140
159
# In our case this is one of the following:
141
160
# - A reward net that has been wrapped, so it accepts transition tensors.
142
- # - A combo of GAN and reward net that accepts latent inputs vectors. (TODO)
161
+ # - A combo of GAN and reward net that accepts latent inputs vectors.
143
162
model_to_analyse = rew_net
144
163
nmf = LayerNMF (
145
164
model = model_to_analyse ,
@@ -157,43 +176,73 @@ def interpret(
157
176
num_features = nmf .channel_dirs .shape [0 ]
158
177
rows , columns = 1 , num_features
159
178
if pyplot :
160
- fig = plt .figure (figsize = (columns * 2 , rows * 2 )) # width, height in inches
179
+ col_mult = 4 if vis_type == "traditional" else 2
180
+ # figsize is width, height in inches
181
+ fig = plt .figure (figsize = (columns * col_mult , rows * 2 ))
161
182
else :
162
183
fig = None
163
184
164
185
# Visualize
165
186
if vis_type == "traditional" :
166
- # List of transforms
167
- transforms = [
168
- transform .jitter (2 ), # Jitters input by 2 pixel
169
- ]
170
-
171
- opt_transitions = nmf .vis_traditional (transforms = transforms )
172
- # This gives as an array that optimizes the objectives, in the shape of the
173
- # input which is a transition tensor. However, lucent helpfully transposes the
174
- # output such that the channel dimension is last. Our functions expect channel
175
- # dim before spatial dims, so we need to transpose it back.
176
- opt_transitions = opt_transitions .transpose (0 , 3 , 1 , 2 )
177
- # Split the optimized transitions, one for each feature, into separate
178
- # observations and actions. This function only works with torch tensors.
179
- obs , acts , next_obs = tensor_to_transition (th .tensor (opt_transitions ))
180
- # obs and next_obs output have channel dim last.
181
- # acts is output as one-hot vector.
187
+
188
+ if gan_path is None :
189
+ # List of transforms
190
+ transforms = [
191
+ transform .jitter (2 ), # Jitters input by 2 pixel
192
+ ]
193
+
194
+ opt_transitions = nmf .vis_traditional (transforms = transforms )
195
+ # This gives as an array that optimizes the objectives, in the shape of the
196
+ # input which is a transition tensor. However, lucent helpfully transposes
197
+ # the output such that the channel dimension is last. Our functions expect
198
+ # channel dim before spatial dims, so we need to transpose it back.
199
+ opt_transitions = opt_transitions .transpose (0 , 3 , 1 , 2 )
200
+ # Split the optimized transitions, one for each feature, into separate
201
+ # observations and actions. This function only works with torch tensors.
202
+ obs , acts , next_obs = tensor_to_transition (th .tensor (opt_transitions ))
203
+ # obs and next_obs output have channel dim last.
204
+ # acts is output as one-hot vector.
205
+
206
+ else :
207
+ # We do not require the latent vectors to be transformed before optimizing.
208
+ # However, we do regularize the L2 norm of latent vectors, to ensure the
209
+ # resulting generated images are realistic.
210
+ opt_latent = nmf .vis_traditional (
211
+ transforms = [],
212
+ l2_coeff = l2_coeff ,
213
+ l2_layer_name = "generator_network_latent_vec" ,
214
+ )
215
+ # Now, we put the latent vector thru the generator to produce transition
216
+ # tensors that we can get observations, actions, etc out of
217
+ opt_latent = np .mean (opt_latent , axis = (1 , 2 ))
218
+ opt_latent_th = th .from_numpy (opt_latent ).to (th .device (device ))
219
+ opt_transitions = gan .generator (opt_latent_th )
220
+ obs , acts , next_obs = tensor_to_transition (opt_transitions )
182
221
183
222
# Set of images, one for each feature, add each to plot
184
223
for feature_i in range (next_obs .shape [0 ]):
185
- sub_img = next_obs [feature_i ]
224
+ sub_img_obs = obs [feature_i ].detach ().cpu ().numpy ()
225
+ sub_img_next_obs = next_obs [feature_i ].detach ().cpu ().numpy ()
186
226
plot_img (
187
227
columns ,
188
228
custom_logger ,
189
229
feature_i ,
190
230
fig ,
191
- sub_img ,
231
+ ( sub_img_obs , sub_img_next_obs ) ,
192
232
pyplot ,
193
233
rows ,
194
234
vis_scale ,
195
235
wandb_logging ,
196
236
)
237
+ if img_save_path is not None :
238
+ obs_PIL = array_to_image (sub_img_obs , vis_scale )
239
+ obs_PIL .save (img_save_path + f"{ feature_i } _obs.png" )
240
+ next_obs_PIL = array_to_image (sub_img_next_obs , vis_scale )
241
+ next_obs_PIL .save (img_save_path + f"{ feature_i } _next_obs.png" )
242
+ custom_logger .log (
243
+ f"Saved feature { feature_i } viz in dir { img_save_path } ."
244
+ )
245
+
197
246
elif vis_type == "dataset" :
198
247
for feature_i in range (num_features ):
199
248
custom_logger .log (f"Feature { feature_i } " )
@@ -213,51 +262,93 @@ def interpret(
213
262
vis_scale ,
214
263
wandb_logging ,
215
264
)
216
- else :
217
- raise ValueError (f"Unknown vis_type: { vis_type } ." )
218
265
219
266
if pyplot :
220
267
plt .show ()
221
268
custom_logger .log ("Done with dataset visualization." )
222
269
223
270
271
def array_to_image(arr: np.ndarray, scale: int) -> Image.Image:
    """Convert an HxWxC numpy array on the [0, 1] scale to an upscaled PIL image.

    Args:
        arr: Image data as a float array with values in [0, 1], with the
            height axis first (the layout ``Image.fromarray`` expects).
        scale: Integer factor by which to enlarge the image.

    Returns:
        The upscaled RGB PIL image (nearest-neighbor resampling, so the
        "pixels" stay sharp).
    """
    # PIL's resize() takes size as (width, height), while numpy image arrays
    # are (height, width, channels) — so shape[1] is the width. The original
    # code swapped these, which only worked for square images.
    return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
        size=(arr.shape[1] * scale, arr.shape[0] * scale),
        resample=Image.NEAREST,
    )
277
+
278
+
224
279
def plot_img(
    columns: int,
    custom_logger: HierarchicalLogger,
    feature_i: int,
    fig: Optional[matplotlib.figure.Figure],
    img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
    pyplot: bool,
    rows: int,
    vis_scale: int,
    wandb_logging: bool,
):
    """Plot the passed image(s) to pyplot and wandb as appropriate.

    Args:
        columns: Number of features per row of the pyplot figure.
        custom_logger: Logger used for wandb logging.
        feature_i: Index of the feature being visualized; determines the
            subplot position and the wandb step.
        fig: Figure to add subplots to, or None when pyplot plotting is off.
        img: Either a single image array, or an (obs, next_obs) pair of
            image arrays for traditional (GAN or non-GAN) visualization.
        pyplot: Whether to plot into the pyplot figure.
        rows: Number of rows of the pyplot figure.
        vis_scale: Scale factor for the wandb images.
        wandb_logging: Whether to log images to wandb.
    """
    _wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
    # Guard on fig as well: fig is None when pyplot plotting was disabled,
    # and calling add_subplot on None would crash.
    if fig is not None and pyplot:
        if isinstance(img, tuple):
            img_obs, img_next_obs = img
            # Each feature occupies two adjacent slots (obs, next_obs), so the
            # grid needs 2 * columns slots per row — add_subplot raises
            # ValueError if the index exceeds nrows * ncols.
            fig.add_subplot(rows, 2 * columns, 2 * feature_i + 1)
            plt.imshow(img_obs)
            fig.add_subplot(rows, 2 * columns, 2 * feature_i + 2)
            plt.imshow(img_next_obs)
        else:
            fig.add_subplot(rows, columns, feature_i + 1)
            plt.imshow(img)
240
303
241
304
242
305
def _wandb_log(
    custom_logger: HierarchicalLogger,
    feature_i: int,
    img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
    vis_scale: int,
    wandb_logging: bool,
):
    """Send the visualization(s) for one feature to wandb, if enabled.

    A tuple is treated as an (obs, next_obs) pair from traditional
    visualization; a single array comes from dataset visualization.
    """
    if wandb_logging:
        if isinstance(img, tuple):
            obs_arr, next_obs_arr = img
            # TODO(df): check if I have to dump between these
            _wandb_log_(obs_arr, vis_scale, feature_i, "obs", custom_logger)
            _wandb_log_(next_obs_arr, vis_scale, feature_i, "next_obs", custom_logger)
        else:
            _wandb_log_(img, vis_scale, feature_i, "dataset_vis", custom_logger)

    # Steps cannot be re-used, so each feature's image gets a step of its own.
    custom_logger.dump(step=feature_i)
259
325
260
326
327
def _wandb_log_(
    arr: np.ndarray,
    scale: int,
    feature: int,
    img_type: str,
    logger: HierarchicalLogger,
) -> None:
    """Log visualized np.ndarray to wandb using given logger.

    Args:
        - arr: array to turn into image, save.
        - scale: ratio by which to scale up the image.
        - feature: which number feature is being visualized.
        - img_type: "obs", "next_obs", or "dataset_vis".
        - logger: logger to use.

    Raises:
        ValueError: if img_type is not one of the recognized values.
    """
    # "dataset_vis" must be accepted: _wandb_log passes it for dataset
    # visualizations, and rejecting it here crashed that code path.
    valid_img_types = ("obs", "next_obs", "dataset_vis")
    if img_type not in valid_img_types:
        err_str = f"img_type should be one of {valid_img_types}, but instead is {img_type}"
        raise ValueError(err_str)

    pil_img = array_to_image(arr, scale)
    wb_img = wandb.Image(pil_img, caption=f"Feature {feature}, {img_type}")
    logger.record(f"feature_{feature}_{img_type}", wb_img)
350
+
351
+
261
352
def main ():
262
353
observer = FileStorageObserver (osp .join ("output" , "sacred" , "interpret" ))
263
354
interpret_ex .observers .append (observer )
0 commit comments