Skip to content

Commit 554d9ef

Browse files
committed
Fix interpret objective
1 parent 84788a1 commit 554d9ef

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

src/reward_preprocessing/interpret.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def interpret(
171171
else: # Use GAN
172172
# Combine rew net with GAN.
173173
gan = th.load(gan_path, map_location=th.device(device))
174-
model_to_analyse = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)
174+
model_to_analyse = RewardGeneratorCombo(
175+
reward_net=rew_net, generator=gan.generator
176+
)
175177

176178
model_to_analyse.eval() # Eval for visualization.
177179

@@ -220,8 +222,8 @@ def interpret(
220222
activation_fn="sigmoid",
221223
)
222224

223-
custom_logger.log(f"Dimensionality reduction (to, from): {nmf.channel_dirs.shape}")
224225
# If these are equal, then of course there is no actual reduction.
226+
custom_logger.log(f"Dimensionality reduction (to, from): {nmf.channel_dirs.shape}")
225227

226228
num_features = nmf.channel_dirs.shape[0]
227229
rows, columns = 2, num_features
@@ -282,6 +284,7 @@ def interpret(
282284
actions = th.tensor(list(range(num_features))).to(device)
283285
assert len(actions) == len(obs)
284286
rews = rew_net(obs.to(device), actions, next_obs.to(device), done=None)
287+
custom_logger.log(f"Rewards: {rews}")
285288

286289
# Use numpy from here.
287290
obs = obs.detach().cpu().numpy()

src/reward_preprocessing/vis/objectives.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
"""Objectives that extend the objectives available in lucent.optvis.objectives"""
2+
from typing import Optional
3+
24
from lucent.optvis.objectives import handle_batch, wrap_objective
35
from lucent.optvis.objectives_util import _extract_act_pos
46
import torch as th
57

68

9+
@wrap_objective()
def max_index_1d(layer: str, i: int, batch: Optional[int] = None):
    """Objective that maximizes the value at index ``i`` of a 1D layer output.

    Args:
        layer: Name of the layer whose output is optimized.
        i: Index of the element (within each batch item) to maximize.
        batch: Optional batch index, forwarded to ``handle_batch``.

    Returns:
        A lucent-style objective callable suitable for ``render_vis``.
    """

    @handle_batch(batch)
    def inner(model):
        # Layer output is (batch_size, n); lucent minimizes the returned
        # value, so negate the mean of the i-th element to maximize it.
        activations = model(layer)
        return -activations[:, i].mean()

    return inner
20+
21+
722
@wrap_objective()
823
def direction_neuron_dim_agnostic(layer, direction, x=None, y=None, batch=None):
924
"""The lucent direction neuron objective, modified to allow 2-dimensional

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from typing import Callable, Dict, List, Optional, Union
55

6+
from lucent.optvis import objectives
67
from lucent.optvis.objectives import handle_batch, wrap_objective
78
import lucent.optvis.param as param
89
import lucent.optvis.render as render
@@ -263,10 +264,16 @@ def vis_traditional(
263264
feature_list = [feature_list]
264265

265266
obj = sum(
267+
# Original with cosine similarity:
268+
# [
269+
# objectives_rfi.direction_neuron_dim_agnostic(
270+
# self.layer_name, self.channel_dirs[feature], batch=feature
271+
# )
272+
# for feature in feature_list
273+
# ]
274+
# New:
266275
[
267-
objectives_rfi.direction_neuron_dim_agnostic(
268-
self.layer_name, self.channel_dirs[feature], batch=feature
269-
)
276+
objectives_rfi.max_index_1d(self.layer_name, feature, batch=feature)
270277
for feature in feature_list
271278
]
272279
)

0 commit comments

Comments
 (0)