Commit a66e44b

Log reward induced by input
1 parent cf29655 commit a66e44b

src/reward_preprocessing/interpret.py

Lines changed: 29 additions & 7 deletions
@@ -167,16 +167,16 @@ def interpret(
         # Imitation reward nets have 4 input args, lucent expects models to only have 1.
         # This wrapper makes it so rew_net accepts a single input which is a
         # transition tensor.
-        rew_net = TensorTransitionWrapper(rew_net)
+        model_to_analyse = TensorTransitionWrapper(rew_net)
     else:  # Use GAN
         # Combine rew net with GAN.
         gan = th.load(gan_path, map_location=th.device(device))
-        rew_net = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)
+        model_to_analyse = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)

-    rew_net.eval()  # Eval for visualization.
+    model_to_analyse.eval()  # Eval for visualization.

     custom_logger.log("Available layers:")
-    custom_logger.log(get_model_layers(rew_net))
+    custom_logger.log(get_model_layers(model_to_analyse))

     # Load the inputs into the model that are used to do dimensionality reduction and
     # getting the shape of activations.
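
Editor's note: as a rough sketch of the single-input adaptation above (not the commit's actual TensorTransitionWrapper), a wrapper along the following lines could adapt the four-argument reward net for lucent. The class name, the split_fn parameter, and the done=None convention are assumptions based on the surrounding diff.

from typing import Callable, Tuple

import torch as th
import torch.nn as nn


class SingleInputRewardWrapper(nn.Module):
    """Hypothetical sketch of a wrapper in the spirit of TensorTransitionWrapper.

    It adapts a reward net with an (obs, acts, next_obs, done) signature into a
    module that takes a single transition tensor, which is what lucent expects.
    """

    def __init__(
        self,
        rew_net: nn.Module,
        split_fn: Callable[[th.Tensor], Tuple[th.Tensor, th.Tensor, th.Tensor]],
    ):
        super().__init__()
        self.rew_net = rew_net
        # split_fn stands in for a helper such as tensor_to_transition in this repo.
        self.split_fn = split_fn

    def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
        obs, acts, next_obs = self.split_fn(transition_tensor)
        # Mirror the call in the diff: done is not used by the reward function.
        return self.rew_net(obs, acts, next_obs, done=None)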
@@ -210,7 +210,6 @@ def interpret(
     # In our case this is one of the following:
     # - A reward net that has been wrapped, so it accepts transition tensors.
     # - A combo of GAN and reward net that accepts latent input vectors.
-    model_to_analyse = rew_net
     nmf = LayerNMF(
         model=model_to_analyse,
         features=num_features,
@@ -239,14 +238,16 @@ def interpret(
         # This does the actual interpretability, i.e. it calculates the
         # visualizations.
         opt_transitions = nmf.vis_traditional(transforms=transforms)
-        # This gives as an array that optimizes the objectives, in the shape of the
+        # This gives us an array that optimizes the objectives, in the shape of the
         # input which is a transition tensor. However, lucent helpfully transposes
         # the output such that the channel dimension is last. Our functions expect
         # channel dim before spatial dims, so we need to transpose it back.
         opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
+        # In the following we need opt_transitions to be a pytorch tensor.
+        opt_transitions = th.tensor(opt_transitions)
         # Split the optimized transitions, one for each feature, into separate
         # observations and actions. This function only works with torch tensors.
-        obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
+        obs, acts, next_obs = tensor_to_transition(opt_transitions)
         # obs and next_obs output have channel dim last.
         # acts is output as one-hot vector.
     else:
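
Editor's note: a minimal sketch of the transpose and tensor conversion in the hunk above, with made-up shapes (8 optimized inputs, 64x64 spatial, 6 channels); only the axis order matters here.

import numpy as np
import torch as th

# Made-up shapes: 8 optimized inputs, 64x64 spatial, 6 channels, channel-last as
# lucent returns them.
opt_transitions = np.random.rand(8, 64, 64, 6).astype(np.float32)
# Move the channel dimension before the spatial dimensions: NHWC -> NCHW.
opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
# Downstream helpers like tensor_to_transition expect torch tensors.
opt_transitions = th.tensor(opt_transitions)
print(opt_transitions.shape)  # torch.Size([8, 6, 64, 64])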
@@ -265,15 +266,36 @@ def interpret(
         opt_transitions = gan.generator(opt_latent_th)
         obs, acts, next_obs = tensor_to_transition(opt_transitions)

+    # What reward does the model output for these generated transitions?
+    # (done isn't used in the reward function.)
+    # There are three possible options here:
+    # - The reward net does not use the action -> it does not matter what we pass
+    #   as the action.
+    # - The reward net does use the action, and we are optimizing an intermediate
+    #   layer -> since the action is only used in the final layer (to choose which
+    #   of the 15 heads has the correct reward), it does not matter what we pass.
+    # - The reward net does use the action, and we are optimizing the final layer
+    #   -> the action index corresponds to the index of the feature.
+    # Note that since the action is only used to choose which head to use, there
+    # are no gradients from the reward to the action. Consequently, acts in
+    # opt_latent is meaningless.
+    actions = th.tensor(list(range(num_features))).to(device)
+    assert len(actions) == len(obs)
+    rews = rew_net(obs.to(device), actions, next_obs.to(device), done=None)
+
     # Use numpy from here.
     obs = obs.detach().cpu().numpy()
     next_obs = next_obs.detach().cpu().numpy()
+    rews = rews.detach().cpu().numpy()

     # We want to plot the name of the action, if applicable.
     features_are_actions = _determine_features_are_actions(nmf, layer_name)

     # Set of images, one for each feature; add each to the plot.
     for feature_i in range(next_obs.shape[0]):
+        # Log the rewards.
+        custom_logger.record(f"reward_feature_{feature_i:02}", rews[feature_i])
+        # Log the images.
         sub_img_obs = obs[feature_i]
         sub_img_next_obs = next_obs[feature_i]
         _log_single_transition_wandb(
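
Editor's note: the three-case comment above is easiest to see with a toy model. The sketch below is a hypothetical multi-head reward net (not the repo's architecture; the single linear layer, the 15-head count, and all shapes are assumptions) in which the action index only selects an output head. That is why pairing feature i with action i is the natural choice when visualizing the final layer, and why no gradient flows from the reward to the action.

import torch as th
import torch.nn as nn


class ToyMultiHeadRewardNet(nn.Module):
    # Hypothetical sketch: one reward prediction per discrete action ("head");
    # the action only selects which head is returned.
    def __init__(self, obs_dim: int = 4, n_actions: int = 15):
        super().__init__()
        self.heads = nn.Linear(obs_dim, n_actions)

    def forward(self, obs, acts, next_obs, done=None):
        all_heads = self.heads(next_obs)  # shape: (batch, n_actions)
        # Indexing picks head acts[i] for sample i; it carries no gradient w.r.t. acts.
        return all_heads[th.arange(len(acts)), acts]


# When optimizing the final layer, feature i corresponds to head i, so passing
# actions = [0, 1, ..., num_features - 1] evaluates each optimized transition
# under its own head.
net = ToyMultiHeadRewardNet()
num_features = 3
obs = th.zeros(num_features, 4)
next_obs = th.randn(num_features, 4)
actions = th.tensor(list(range(num_features)))
print(net(obs, actions, next_obs))  # three rewards, one per feature/head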
