@@ -167,16 +167,16 @@ def interpret(
        # Imitation reward nets have 4 input args, lucent expects models to only have 1.
        # This wrapper makes it so rew_net accepts a single input which is a
        # transition tensor.
-       rew_net = TensorTransitionWrapper(rew_net)
+       model_to_analyse = TensorTransitionWrapper(rew_net)
    else:  # Use GAN
        # Combine rew net with GAN.
        gan = th.load(gan_path, map_location=th.device(device))
-       rew_net = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)
+       model_to_analyse = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)

-   rew_net.eval()  # Eval for visualization.
+   model_to_analyse.eval()  # Eval for visualization.

    custom_logger.log("Available layers:")
-   custom_logger.log(get_model_layers(rew_net))
+   custom_logger.log(get_model_layers(model_to_analyse))

    # Load the inputs into the model that are used to do dimensionality reduction and
    # getting the shape of activations.
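As an aside on the wrapper comment above: a minimal sketch of the kind of adapter `TensorTransitionWrapper` implements, assuming a channel layout of obs / one-hot action planes / next_obs stacked along dim 1 (the actual wrapper in this repo may split the tensor differently):

import torch as th
from torch import nn

class TransitionTensorAdapterSketch(nn.Module):
    """Illustrative only: adapts a 4-argument reward net so it can be called
    with a single transition tensor, as lucent expects."""

    def __init__(self, reward_net: nn.Module):
        super().__init__()
        self.reward_net = reward_net

    def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
        # Assumed layout: 3 obs channels, 15 one-hot action planes, 3 next_obs channels.
        obs = transition_tensor[:, :3]
        acts = transition_tensor[:, 3:18].mean(dim=(2, 3))  # collapse planes to a one-hot-ish vector
        next_obs = transition_tensor[:, 18:]
        return self.reward_net(obs, acts, next_obs, done=None)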
@@ -210,7 +210,6 @@ def interpret(
    # In our case this is one of the following:
    # - A reward net that has been wrapped, so it accepts transition tensors.
    # - A combo of GAN and reward net that accepts latent inputs vectors.
-   model_to_analyse = rew_net
    nmf = LayerNMF(
        model=model_to_analyse,
        features=num_features,
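For context on `LayerNMF`, which comes from the lucent-based tooling used here: as a rough sketch of the underlying idea (not this repo's implementation), non-negative matrix factorization reduces a layer's activations to `num_features` directions:

import numpy as np
from sklearn.decomposition import NMF

# Dummy activations standing in for a conv layer's output: (batch, channels, H, W).
acts = np.abs(np.random.randn(8, 64, 16, 16)).astype(np.float32)
# Flatten spatial positions so each row is one position's channel vector.
flat = acts.transpose(0, 2, 3, 1).reshape(-1, acts.shape[1])
# Factor into a small number of non-negative "features" (here 4).
nmf = NMF(n_components=4, init="random", max_iter=200)
weights = nmf.fit_transform(flat)   # strength of each feature at each position
components = nmf.components_        # (n_components, channels) feature directions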
@@ -239,14 +238,16 @@ def interpret(
        # This does the actual interpretability, i.e. it calculates the
        # visualizations.
        opt_transitions = nmf.vis_traditional(transforms=transforms)
-       # This gives as an array that optimizes the objectives, in the shape of the
+       # This gives us an array that optimizes the objectives, in the shape of the
        # input which is a transition tensor. However, lucent helpfully transposes
        # the output such that the channel dimension is last. Our functions expect
        # channel dim before spatial dims, so we need to transpose it back.
        opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
+       # In the following we need opt_transitions to be a pytorch tensor.
+       opt_transitions = th.tensor(opt_transitions)
        # Split the optimized transitions, one for each feature, into separate
        # observations and actions. This function only works with torch tensors.
-       obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
+       obs, acts, next_obs = tensor_to_transition(opt_transitions)
        # obs and next_obs output have channel dim last.
        # acts is output as one-hot vector.
    else:
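To make the channel-ordering discussion above concrete, here is a small self-contained sketch; the channel split shown for `tensor_to_transition` is an assumption, not the repo's actual layout:

import numpy as np
import torch as th

# lucent returns optimized images with the channel dimension last: (batch, H, W, C).
opt_nhwc = np.random.rand(4, 64, 64, 21).astype(np.float32)
opt_nchw = opt_nhwc.transpose(0, 3, 1, 2)  # move channels before the spatial dims
opt = th.tensor(opt_nchw)

# Hypothetical split into obs (3 ch), one-hot action planes (15 ch), next_obs (3 ch).
obs, act_planes, next_obs = opt[:, :3], opt[:, 3:18], opt[:, 18:]
acts = act_planes.mean(dim=(2, 3))  # one score per action, roughly one-hot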
@@ -265,15 +266,36 @@ def interpret(
        opt_transitions = gan.generator(opt_latent_th)
        obs, acts, next_obs = tensor_to_transition(opt_transitions)

+   # What reward does the model output for these generated transitions?
+   # (done isn't used in the reward function)
+   # There are three possible options here:
+   # - The reward net does not use action -> it does not matter what we pass as
+   #   action.
+   # - The reward net does use action, and we are optimizing an intermediate layer
+   #   -> since action is only used on the final layer (to choose which of the 15
+   #   heads has the correct reward), it does not matter what we pass as action.
+   # - The reward net does use action, and we are optimizing the final layer
+   #   -> the action index of the action corresponds to the index of the feature.
+   # Note that since actions is only used to choose which head to use, there are no
+   # gradients from the reward to the action. Consequently, acts in opt_latent is
+   # meaningless.
+   actions = th.tensor(list(range(num_features))).to(device)
+   assert len(actions) == len(obs)
+   rews = rew_net(obs.to(device), actions, next_obs.to(device), done=None)
+
    # Use numpy from here.
    obs = obs.detach().cpu().numpy()
    next_obs = next_obs.detach().cpu().numpy()
+   rews = rews.detach().cpu().numpy()

    # We want to plot the name of the action, if applicable.
    features_are_actions = _determine_features_are_actions(nmf, layer_name)

    # Set of images, one for each feature, add each to plot
    for feature_i in range(next_obs.shape[0]):
+       # Log the rewards
+       custom_logger.record(f"reward_feature_{feature_i:02}", rews[feature_i])
+       # Log the images
        sub_img_obs = obs[feature_i]
        sub_img_next_obs = next_obs[feature_i]
        _log_single_transition_wandb(
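The comments about the 15 reward heads can be illustrated with a hypothetical multi-head reward net (not the one in this repo): the action only indexes which head's output is returned, so no gradient flows from the reward back to the action, and passing action index `feature_i` simply reads the head matching that feature:

import torch as th
from torch import nn

class MultiHeadRewardSketch(nn.Module):
    """Hypothetical reward net with one scalar head per discrete action (15 here)."""

    def __init__(self, n_actions: int = 15):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.LazyLinear(64), nn.ReLU())
        self.heads = nn.Linear(64, n_actions)

    def forward(self, obs, acts, next_obs, done=None):
        all_rewards = self.heads(self.backbone(th.cat([obs, next_obs], dim=1)))
        # acts is used purely as an index into the heads, so the returned reward
        # carries no gradient with respect to the action.
        return all_rewards.gather(1, acts.long().unsqueeze(1)).squeeze(1)

rew_sketch = MultiHeadRewardSketch()
obs = th.rand(15, 3, 64, 64)
next_obs = th.rand(15, 3, 64, 64)
actions = th.arange(15)  # action index i picks the head for feature i
rews = rew_sketch(obs, actions, next_obs, done=None)  # shape: (15,)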