
Commit fd1ee85

Merge pull request #32 from HumanCompatibleAI/print_dataset_vis_rewards
Print out rewards of dataset visualized things
2 parents: ab3d103 + dd664f9

File tree: 4 files changed, +28 −17 lines

src/reward_preprocessing/common/utils.py

+1 −11
@@ -155,7 +155,7 @@ def tensor_to_transition(
     trans_tens: th.Tensor,
 ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
     """Turn a generated 'transition tensor' batch into a batch of bona fide
-    transitions. Output observations will have channel dim last, activations will be
+    transitions. Output observations will have channel dim last, actions will be
     output as one-hot vectors.
     Assumes input transition tensor has values between 0 and 1.
     """
@@ -213,16 +213,6 @@ def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
         # tensor_to_transition expects.
         obs, act, next_obs = tensor_to_transition(transition_tensor)

-        # TODO: Remove this once this becomes superfluous.
-        if self.rew_net.normalize_images:
-            # Imitation reward nets have this flag which basically decides whether
-            # observations will be divided by 255 (before being passed to the conv
-            # layers). If this flag is set they expect images to be between 0 and 255.
-            # The interpret and lucent code provides images between 0 and 1, so we
-            # scale up.
-            obs = obs * 255
-            next_obs = next_obs * 255
-
         dones = th.zeros_like(obs[:, 0])
         return self.rew_net(state=obs, action=act, next_state=next_obs, done=dones)
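For context on the deleted branch: imitation-style reward nets expose a normalize_images flag that divides image observations by 255 before the conv layers, so they expect inputs in [0, 255], while the interpret/lucent code produces images in [0, 1]. A minimal sketch of the now-removed workaround, using a hypothetical scale_for_net helper that is not part of the codebase:

import torch as th

def scale_for_net(obs: th.Tensor, normalize_images: bool) -> th.Tensor:
    # Mirrors the removed branch: nets with normalize_images=True divide
    # by 255 internally, so [0, 1] images must be scaled up first.
    return obs * 255 if normalize_images else obs

print(scale_for_net(th.ones(1, 4, 4, 3), True).max())  # tensor(255.)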

src/reward_preprocessing/interpret.py

+18 −4
@@ -169,7 +169,7 @@ def interpret(

     device = "cuda" if th.cuda.is_available() else "cpu"

-    # Load reward not pytorch module
+    # Load reward net pytorch module
     rew_net = th.load(str(reward_path), map_location=th.device(device))

     if gan_path is None:
@@ -228,7 +228,7 @@ def interpret(
             # input samples are used for dim reduction (if features is not
             # None) and for determining the shape of the features.
             model_inputs_preprocess=inputs,
-            activation_fn="sigmoid",
+            activation_fn="relu",
         )

         # If these are equal, then of course there is no actual reduction.
@@ -245,7 +245,6 @@ def interpret(
     if gan_path is None:
         # List of transforms
         transforms = _determine_transforms(reg)
-
         # This does the actual interpretability, i.e. it calculates the
         # visualizations.
         opt_transitions = nmf.vis_traditional(transforms=transforms)
@@ -384,9 +383,24 @@ def param_f():
             custom_logger.log(f"Feature {feature_i}")

             dataset_thumbnails, indices = nmf.vis_dataset_thumbnail(
-                feature=feature_i, num_mult=4, expand_mult=1
+                feature=feature_i,
+                num_mult=4,
+                expand_mult=1,
             )

+            if nmf.reducer is None:
+                # print out rewards
+                flat_indices = []
+                for index_list in indices:
+                    flat_indices += index_list
+                obses, _, next_obses = tensor_to_transition(inputs[flat_indices])
+                feature_i_rep = th.Tensor([feature_i] * len(flat_indices)).long()
+                action_i_tens = th.nn.functional.one_hot(
+                    feature_i_rep, num_classes=num_features
+                ).to(device)
+                rewards = rew_net(obses, action_i_tens, next_obses, done=None)
+                custom_logger.log(f"Rewards for feature {feature_i}: {rewards}")
+
             # remove opacity channel from dataset thumbnails
             np_trans_tens = dataset_thumbnails[:-1, :, :]
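As a usage note on the action encoding added above: th.nn.functional.one_hot takes a long tensor of class indices and appends a class dimension of size num_classes. A self-contained sketch with made-up sizes (batch of 4 transitions, 15 actions):

import torch as th

feature_i, batch, num_features = 2, 4, 15
idx = th.tensor([feature_i] * batch, dtype=th.long)
# Each row is a one-hot vector selecting action index `feature_i`.
actions = th.nn.functional.one_hot(idx, num_classes=num_features)
print(actions.shape)  # torch.Size([4, 15])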

src/reward_preprocessing/procgen.py

+1 −1
@@ -66,7 +66,7 @@ class ProcgenFinalObsWrapper(gym.Wrapper):
     """Returns the final observation of gym3 procgen environment, correcting for the
     fact that Procgen gym environments return the second-to-last observation again
     instead of the final observation.
-
+
     Only works correctly when the 'done' signal coincides with the end of an episode
     (which is not the case when using e.g. the seals AutoResetWrapper).
     Requires the use of the PavelCz/procgenAISC fork, which adds the 'final_obs' value.
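A minimal sketch of the pattern this docstring describes, assuming the procgenAISC fork exposes the true last frame under an info["final_obs"] key (the key name and the classic 4-tuple gym step API are assumptions here, not confirmed by this diff):

import gym

class FinalObsWrapper(gym.Wrapper):
    """Swap the duplicated second-to-last observation for the real final one."""

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if done and "final_obs" in info:
            # Use the true terminal frame provided by the fork.
            obs = info["final_obs"]
        return obs, reward, done, info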

src/reward_preprocessing/vis/reward_vis.py

+8 −1
@@ -184,6 +184,9 @@ def __init__(
         # Apply activation function if specified.
         if activation_fn == "sigmoid":
             activations = th.sigmoid(activations)
+        elif activation_fn == "relu":
+            relu_func = th.nn.ReLU()
+            activations = relu_func(activations)
         elif activation_fn is not None:
             raise ValueError(f"Unsupported activation_fn: {activation_fn}")
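Instantiating th.nn.ReLU() just to call it once works; the functional form th.relu would do the same without constructing a module object. A quick check that the two agree:

import torch as th

activations = th.tensor([-1.0, 0.5, 2.0])
# Module and functional ReLU are element-wise identical.
assert th.equal(th.nn.ReLU()(activations), th.relu(activations))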

@@ -286,12 +289,14 @@ def vis_traditional(
                 for feature in feature_list
             ]
         )
+
         if l2_coeff != 0.0:
             if l2_layer_name is None:
                 raise ValueError(
                     "l2_layer_name must be specified if l2_coeff is non-zero"
                 )
             obj -= l2_objective(l2_layer_name, l2_coeff)
+
         input_shape = tuple(self.model_inputs_preprocess.shape[1:])

         if param_f is None:
@@ -302,6 +307,7 @@ def param_f():
                     h=input_shape[1],
                     w=input_shape[2],
                     batch=len(feature_list),
+                    sd=1,
                 )

         logging.info(f"Performing vis_traditional with transforms: {transforms}")
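If param_f here wraps lucent's param.image (the repo's comments mention lucent, but this diff does not show the call target), then sd sets the standard deviation of the random image initialization, so sd=1 starts optimization from noisier images than the library default. A sketch under that assumption:

from lucent.optvis import param

# Hypothetical standalone call mirroring the diff's arguments:
# a batch of four 64x64 images initialized from noise with sd=1.
param_f = lambda: param.image(64, h=64, batch=4, sd=1)
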
@@ -472,13 +478,14 @@ def vis_dataset_thumbnail(
         pos_indices = argmax_nd(
             acts_feature, axes=[1, 2], max_rep=max_rep, max_rep_strict=True
         )
-        # The actual maximum values of the activations, accroding to max_rep setting.
+        # The actual maximum values of the activations, according to max_rep setting.
         acts_single = acts_feature[
             range(acts_feature.shape[0]), pos_indices[0], pos_indices[1]
         ]
         # Sort the activations in descending order and take the num_mult**2 strongest
         # activations.
         obs_indices = np.argsort(-acts_single, axis=0)[: num_mult**2]
+
         # Coordinates of the strongest activation in each observation.
         coords = np.array(list(zip(*pos_indices)), dtype=[("h", int), ("w", int)])[
             obs_indices
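The negation inside np.argsort is the usual idiom for a descending sort, since argsort is ascending-only. A tiny sketch:

import numpy as np

acts = np.array([0.2, 0.9, 0.5, 0.1])
# Sorting the negated values yields indices of the largest values first.
top2 = np.argsort(-acts)[:2]
print(top2)  # [1 2]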
