 import os.path as osp
-from typing import Optional
+from typing import Optional, Tuple, Union

 from PIL import Image
 from imitation.scripts.common import common as common_config
 import wandb

 from reward_preprocessing.common.utils import (
+    RewardGeneratorCombo,
     TensorTransitionWrapper,
     rollouts_to_dataloader,
     tensor_to_transition,
@@ -40,7 +41,9 @@ def interpret(
     vis_type: str,
     layer_name: str,
     num_features: Optional[int],
-    gan_path: Optional[str],
+    gan_path: Optional[str] = None,
+    l2_coeff: Optional[float] = None,
+    img_save_path: Optional[str] = None,
 ):
     """Run visualization for interpretability.

@@ -74,6 +77,12 @@ def interpret(
             Path to the GAN model. This is used to regularize the output of the
             visualization. If None simply visualize reward net without the use
             of a GAN in the pipeline.
+        l2_coeff:
+            Strength with which to penalize the L2 norm of generated latent vector
+            "visualizations" of a GAN-reward model combination. If gan_path is not None,
+            this must also not be None.
+        img_save_path:
+            Directory to save images in. Must end in a /. If None, do not save images.
     """
     if limit_num_obs <= 0:
         raise ValueError(
@@ -82,6 +91,15 @@ def interpret(
             f"I don't think we actually ever want to use all so this is currently not "
             f"implemented."
         )
+    if vis_type not in ["dataset", "traditional"]:
+        raise ValueError(f"Unknown vis_type: {vis_type}")
+    if vis_type == "dataset" and gan_path is not None:
+        raise ValueError("GANs cannot be used with dataset visualization.")
+    if gan_path is not None and l2_coeff is None:
+        raise ValueError("When GANs are used, l2_coeff must be set.")
+    if img_save_path is not None and img_save_path[-1] != "/":
+        raise ValueError("img_save_path is not a directory, does not end in /")
+
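+    # For orientation, two configurations that pass these checks (the values are
+    # purely illustrative): vis_type="dataset" with gan_path=None, or
+    # vis_type="traditional" with gan_path="gan.pt", l2_coeff=0.01, img_save_path="out/".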
     # Set up imitation-style logging.
     custom_logger, log_dir = common_config.setup_logging()
     wandb_logging = "wandb" in common["log_format_strs"]
@@ -101,7 +119,8 @@ def interpret(
         rew_net = TensorTransitionWrapper(rew_net)
     else:  # Use GAN
         # Combine rew net with GAN.
-        raise NotImplementedError()
+        gan = th.load(gan_path, map_location=th.device(device))
+        rew_net = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)
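+        # Note: the object loaded from gan_path is assumed to expose a .generator
+        # network and a .sample(n) method, since both are used below.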

     rew_net.eval()  # Eval for visualization.

@@ -129,17 +148,17 @@ def interpret(
         # Ensure loaded data is FloatTensor and not DoubleTensor.
         inputs = inputs.float()
     else:  # When using GAN.
-        # Inputs should be some samples of input vectors? Not sure if this is the best
-        # way to do this, there might be better options.
-        # The important part is that lucent expects 4D tensors as inputs, so increase
-        # dimensionality accordingly.
-        raise NotImplementedError()
+        # Inputs are GAN samples; lucent expects 4D tensors, so add singleton spatial dims.
+        samples = gan.sample(limit_num_obs)
+        inputs = samples[:, :, None, None]
+        inputs = inputs.to(device)
+        inputs = inputs.float()
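+        # Rough shape sketch (assuming the GAN samples 2D latent vectors): samples is
+        # (limit_num_obs, latent_dim), so inputs becomes (limit_num_obs, latent_dim, 1, 1).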

     # The model to analyse should be a torch module that takes a single input, which
     # should be a torch Tensor.
     # In our case this is one of the following:
     # - A reward net that has been wrapped, so it accepts transition tensors.
-    # - A combo of GAN and reward net that accepts latent inputs vectors. (TODO)
+    # - A combo of GAN and reward net that accepts latent input vectors.
     model_to_analyse = rew_net
     nmf = LayerNMF(
         model=model_to_analyse,
@@ -157,43 +176,73 @@ def interpret(
     num_features = nmf.channel_dirs.shape[0]
     rows, columns = 1, num_features
     if pyplot:
-        fig = plt.figure(figsize=(columns * 2, rows * 2))  # width, height in inches
+        col_mult = 4 if vis_type == "traditional" else 2
+        # figsize is width, height in inches.
+        fig = plt.figure(figsize=(columns * col_mult, rows * 2))
     else:
         fig = None

     # Visualize
     if vis_type == "traditional":
-        # List of transforms
-        transforms = [
-            transform.jitter(2),  # Jitters input by 2 pixel
-        ]
-
-        opt_transitions = nmf.vis_traditional(transforms=transforms)
-        # This gives as an array that optimizes the objectives, in the shape of the
-        # input which is a transition tensor. However, lucent helpfully transposes the
-        # output such that the channel dimension is last. Our functions expect channel
-        # dim before spatial dims, so we need to transpose it back.
-        opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
-        # Split the optimized transitions, one for each feature, into separate
-        # observations and actions. This function only works with torch tensors.
-        obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
-        # obs and next_obs output have channel dim last.
-        # acts is output as one-hot vector.
+
+        if gan_path is None:
+            # List of transforms
+            transforms = [
+                transform.jitter(2),  # Jitters input by 2 pixels
+            ]
+
+            opt_transitions = nmf.vis_traditional(transforms=transforms)
+            # This gives us an array that optimizes the objectives, in the shape of the
+            # input, which is a transition tensor. However, lucent helpfully transposes
+            # the output such that the channel dimension is last. Our functions expect
+            # channel dim before spatial dims, so we need to transpose it back.
+            opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
+            # Split the optimized transitions, one for each feature, into separate
+            # observations and actions. This function only works with torch tensors.
+            obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
+            # obs and next_obs output have channel dim last.
+            # acts is output as one-hot vector.
+
+        else:
+            # We do not require the latent vectors to be transformed before optimizing.
+            # However, we do regularize the L2 norm of latent vectors, to ensure the
+            # resulting generated images are realistic.
+            opt_latent = nmf.vis_traditional(
+                transforms=[],
+                l2_coeff=l2_coeff,
+                l2_layer_name="generator_network_latent_vec",
+            )
+            # Now, we put the latent vectors through the generator to produce transition
+            # tensors from which we can extract observations, actions, etc.
+            opt_latent = np.mean(opt_latent, axis=(1, 2))
+            opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
+            opt_transitions = gan.generator(opt_latent_th)
+            obs, acts, next_obs = tensor_to_transition(opt_transitions)
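+            # Rough shape sketch (assuming lucent returns channel-last arrays):
+            # opt_latent (num_features, H, W, latent_dim) -> mean over axes (1, 2)
+            # -> (num_features, latent_dim) -> generator -> transition tensors
+            # -> (obs, acts, next_obs).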

         # Set of images, one for each feature, add each to plot
         for feature_i in range(next_obs.shape[0]):
-            sub_img = next_obs[feature_i]
+            sub_img_obs = obs[feature_i].detach().cpu().numpy()
+            sub_img_next_obs = next_obs[feature_i].detach().cpu().numpy()
             plot_img(
                 columns,
                 custom_logger,
                 feature_i,
                 fig,
-                sub_img,
+                (sub_img_obs, sub_img_next_obs),
                 pyplot,
                 rows,
                 vis_scale,
                 wandb_logging,
             )
+            if img_save_path is not None:
+                obs_PIL = array_to_image(sub_img_obs, vis_scale)
+                obs_PIL.save(img_save_path + f"{feature_i}_obs.png")
+                next_obs_PIL = array_to_image(sub_img_next_obs, vis_scale)
+                next_obs_PIL.save(img_save_path + f"{feature_i}_next_obs.png")
+                custom_logger.log(
+                    f"Saved feature {feature_i} viz in dir {img_save_path}."
+                )
+
     elif vis_type == "dataset":
         for feature_i in range(num_features):
             custom_logger.log(f"Feature {feature_i}")
@@ -213,51 +262,93 @@ def interpret(
                 vis_scale,
                 wandb_logging,
             )
-    else:
-        raise ValueError(f"Unknown vis_type: {vis_type}.")

     if pyplot:
         plt.show()
     custom_logger.log("Done with dataset visualization.")


+def array_to_image(arr: np.ndarray, scale: int) -> Image.Image:
+    """Take a numpy array on [0, 1] scale and return a scaled-up PIL image."""
+    return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
+        size=(arr.shape[0] * scale, arr.shape[1] * scale),
+        resample=Image.NEAREST,
+    )
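+# Illustrative example: a (64, 64, 3) float array with values in [0, 1] and scale=4
+# yields a 256x256 RGB image, upscaled with nearest-neighbour resampling.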
+
+
 def plot_img(
     columns: int,
     custom_logger: HierarchicalLogger,
     feature_i: int,
     fig: Optional[matplotlib.figure.Figure],
-    img: np.ndarray,
+    img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
     pyplot: bool,
     rows: int,
     vis_scale: int,
     wandb_logging: bool,
 ):
-    """Plot the passed image to pyplot and wandb as appropriate."""
+    """Plot the passed image(s) to pyplot and wandb as appropriate."""
     _wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
-    if fig is not None and pyplot:
-        fig.add_subplot(rows, columns, feature_i + 1)
-        plt.imshow(img)
+    if pyplot:
+        if isinstance(img, tuple):
+            img_obs = img[0]
+            img_next_obs = img[1]
+            # Two images (obs, next_obs) per feature, so the subplot grid needs
+            # twice as many columns.
+            fig.add_subplot(rows, 2 * columns, 2 * feature_i + 1)
+            plt.imshow(img_obs)
+            fig.add_subplot(rows, 2 * columns, 2 * feature_i + 2)
+            plt.imshow(img_next_obs)
+        else:
+            fig.add_subplot(rows, columns, feature_i + 1)
+            plt.imshow(img)


 def _wandb_log(
     custom_logger: HierarchicalLogger,
     feature_i: int,
-    img: np.ndarray,
+    img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
     vis_scale: int,
     wandb_logging: bool,
 ):
     """Plot to wandb if wandb logging is enabled."""
     if wandb_logging:
-        p_img = Image.fromarray(np.uint8(img * 255), mode="RGB").resize(
-            size=(img.shape[0] * vis_scale, img.shape[1] * vis_scale),
-            resample=Image.NEAREST,
-        )
-        wb_img = wandb.Image(p_img, caption=f"Feature {feature_i}")
-        custom_logger.record(f"feature_{feature_i}", wb_img)
+        if isinstance(img, tuple):
+            img_obs = img[0]
+            img_next_obs = img[1]
+            # TODO(df): check if I have to dump between these
+            _wandb_log_(img_obs, vis_scale, feature_i, "obs", custom_logger)
+            _wandb_log_(img_next_obs, vis_scale, feature_i, "next_obs", custom_logger)
+        else:
+            _wandb_log_(img, vis_scale, feature_i, "dataset_vis", custom_logger)
+
         # Can't re-use steps unfortunately, so each feature img gets its own step.
         custom_logger.dump(step=feature_i)


+def _wandb_log_(
+    arr: np.ndarray,
+    scale: int,
+    feature: int,
+    img_type: str,
+    logger: HierarchicalLogger,
+) -> None:
+    """Log a visualized np.ndarray to wandb using the given logger.
+
+    Args:
+        arr: array to turn into an image and log.
+        scale: ratio by which to scale up the image.
+        feature: which feature number is being visualized.
+        img_type: "obs", "next_obs", or "dataset_vis".
+        logger: logger to use.
+    """
+    if img_type not in ["obs", "next_obs", "dataset_vis"]:
+        err_str = (
+            f"img_type should be 'obs', 'next_obs' or 'dataset_vis', but is {img_type}"
+        )
+        raise ValueError(err_str)
+
+    pil_img = array_to_image(arr, scale)
+    wb_img = wandb.Image(pil_img, caption=f"Feature {feature}, {img_type}")
+    # Keys have the form "feature_{i}_obs", "feature_{i}_next_obs" or
+    # "feature_{i}_dataset_vis".
+    logger.record(f"feature_{feature}_{img_type}", wb_img)
+
+
 def main():
     observer = FileStorageObserver(osp.join("output", "sacred", "interpret"))
     interpret_ex.observers.append(observer)