Skip to content

Commit 0ad88d3

Browse files
committed
Merge branch 'try-feature-vis' into extend-interpret
2 parents 3b54248 + 71665d5 commit 0ad88d3

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,13 @@ def log_img_wandb(
258258
pil_img = array_to_image(img, scale)
259259
elif isinstance(img, PIL.Image.Image):
260260
pil_img = img.resize(
261-
# PIL expects tuple of (width, height), numpy's index 1 is width, 0 height.
261+
# PIL expects tuple of (width, height), as opposed to numpy's
262+
# (height, width).
262263
size=(img.width * scale, img.height * scale),
263264
resample=Image.NEAREST,
264265
)
265266
else:
266-
raise ValueError(
267-
f"img must be np.ndarray or PIL.Image.Image, {type(img)=}"
268-
)
267+
raise ValueError(f"img must be np.ndarray or PIL.Image.Image, {type(img)=}")
269268
wb_img = wandb.Image(pil_img, caption=caption)
270269
logger.record(wandb_key, wb_img)
271270
if step is not None:
@@ -275,7 +274,8 @@ def log_img_wandb(
275274
def array_to_image(arr: np.ndarray, scale: int) -> PIL.Image.Image:
276275
"""Take numpy array on [0,1] scale, return PIL image."""
277276
return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
278-
# PIL expects tuple of (width, height), numpy's index 1 is width, 0 height.
277+
# PIL expects tuple of (width, height), numpy's dimension 1 is width, and
278+
# dimension 0 height.
279279
size=(arr.shape[1] * scale, arr.shape[0] * scale),
280280
resample=Image.NEAREST,
281281
)

src/reward_preprocessing/interpret.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ def _determine_features_are_actions(nmf: LayerNMF, layer_name: str) -> bool:
6262
human-understandable action name instead of the feature index."""
6363
# This is the heuristic for determining whether features are actions:
6464
# - If there is no dim reduction
65-
# - If it is one of the layers from the list
65+
# - If the name of the layer suggests that we are analysing the final layer of a
66+
# reward net
6667
# - If the number of features is 15 since that is the number of actions in all
6768
# procgen games
6869
return (
6970
nmf.channel_dirs.shape[0] == nmf.channel_dirs.shape[1]
70-
and layer_name in ["rew_net_cnn_dense_final"]
71+
and layer_name.endswith("dense_final")
7172
and nmf.channel_dirs.shape[0] == 15
7273
)
7374

src/reward_preprocessing/vis/attribution.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@ def get_activations(
1515
# size. Generally the number of inputs would be in the thousands, so we don't want
1616
# to run the entire batch through the model at once.
1717
batch_size = 128
18+
t_acts = []
1819
for i in range(0, len(model_inputs), batch_size):
1920
model(model_inputs[i : i + batch_size])
2021

21-
# Get activations at layer.
22-
t_acts = hook(layer_name)
22+
# Get activations at layer.
23+
act_batch = hook(layer_name)
24+
t_acts.append(act_batch)
25+
26+
t_acts = th.cat(t_acts, dim=0)
27+
assert t_acts.shape[0] == len(model_inputs)
2328

2429
# Reward activations might be 2 dimensional (scalar + batch dimension) e.g.
2530
# for linear layers. In this case we unsqueeze.

0 commit comments

Comments
 (0)