
Commit e61bf0e

Merge pull request #13 from HumanCompatibleAI/clean-interpret
Minor fixes related to interpret
2 parents c93182f + 6ebfecb commit e61bf0e

File tree: 4 files changed (+56, -12 lines)


src/reward_preprocessing/common/utils.py

Lines changed: 22 additions & 5 deletions
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Tuple
+from typing import List, Optional, Tuple, Union
 
 import PIL
 from imitation.data import rollout, types
@@ -26,14 +26,14 @@ def make_transition_to_tensor(num_acts):
     def transition_to_tensor(transition):
         obs = transition["obs"]
         if np.issubdtype(obs.dtype, np.integer):
-            obs = obs.float() / 255.0
+            obs = obs / 255.0
         # For floats we don't divide by 255.0. In that case we assume the
         # observation is already in the range [0, 1].
         act = int(transition["acts"])
         next_obs = transition["next_obs"]
 
         if np.issubdtype(next_obs.dtype, np.integer):
-            next_obs = next_obs.float() / 255.0
+            next_obs = next_obs / 255.0
 
         transp_obs = np.transpose(obs, (2, 0, 1))
         obs_height = transp_obs.shape[1]
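Aside (not part of the commit): a minimal sketch of why this hunk matters. The observations here are numpy arrays, which have no .float() method (that is a torch.Tensor method), so plain division by 255.0 is the way to both normalize and upcast:

import numpy as np

obs = np.zeros((64, 64, 3), dtype=np.uint8)  # e.g. an 8-bit image observation

if np.issubdtype(obs.dtype, np.integer):
    # obs.float() would raise AttributeError on a numpy array; dividing by 255.0
    # normalizes to [0, 1] and upcasts to float in one step.
    obs = obs / 255.0

assert obs.dtype == np.float64 and obs.max() <= 1.0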
@@ -70,27 +70,44 @@ def __len__(self):
         return self.base_dataset.__len__()
 
 
-def rollouts_to_dataloader(rollouts_paths, num_acts, batch_size):
+def rollouts_to_dataloader(
+    rollouts_paths: Union[str, List[str]],
+    num_acts: int,
+    batch_size: int,
+    n_trajectories: Optional[int] = None,
+):
     """Take saved rollouts of a policy, and produce a dataloader of transitions.
 
     Assumes that observations are (h,w,c)-formatted images and that actions are
     discrete.
 
     Args:
-        rollouts_path: Path to rollouts saved via imitation script, or list of
+        rollouts_paths: Path to rollouts saved via imitation script, or list of
             such paths.
         num_acts: Number of actions available to the agent (necessary because
            actions are saved as a number, not as a one-hot vector).
         batch_size: Int, size of batches that the dataloader serves. Note that
            a batch size of 2 will make the GAN algorithm think each batch is
            a (data, label) pair, which will mess up training.
+        n_trajectories: If not None, limit number of trajectories to use.
     """
     if isinstance(rollouts_paths, list):
         rollouts = []
         for path in rollouts_paths:
             rollouts += types.load_with_rewards(path)
     else:
         rollouts = types.load_with_rewards(rollouts_paths)
+
+    # Optionally limit the number of trajectories to use, similar to n_expert_demos in
+    # imitation.scripts.common.demonstrations.
+    if n_trajectories is not None:
+        if len(rollouts) < n_trajectories:
+            raise ValueError(
+                f"Want to use n_trajectories={n_trajectories} trajectories, but only "
+                f"{len(rollouts)} are available via {rollouts_paths}.",
+            )
+        rollouts = rollouts[:n_trajectories]
+
     flat_rollouts = rollout.flatten_trajectories_with_rew(rollouts)
     tensor_rollouts = TransformedDataset(
         flat_rollouts, make_transition_to_tensor(num_acts)
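For context, a hedged usage sketch of the new n_trajectories parameter; the path and numbers below are made up, only the signature comes from the diff above:

from reward_preprocessing.common.utils import rollouts_to_dataloader

# Hypothetical rollout file saved via the imitation scripts.
dataloader = rollouts_to_dataloader(
    rollouts_paths="rollouts/expert_rollouts.npz",
    num_acts=15,
    batch_size=32,
    n_trajectories=10,  # use at most 10 trajectories; raises ValueError if fewer exist
)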

src/reward_preprocessing/interpret.py

Lines changed: 25 additions & 4 deletions
@@ -3,6 +3,7 @@
 
 from PIL import Image
 from imitation.scripts.common import common as common_config
+from imitation.util.logger import HierarchicalLogger
 from lucent.modelzoo.util import get_model_layers
 from lucent.optvis import transform
 import matplotlib
@@ -116,11 +117,17 @@ def interpret(
             rollouts_paths=rollout_path,
             num_acts=15,
             batch_size=limit_num_obs,
+            # This is an upper bound of the number of trajectories we need, since every
+            # trajectory has at least 1 transition.
+            n_trajectories=limit_num_obs,
         )
         # For dim reductions and gettings activations in LayerNMF we want one big batch
         # of limit_num_obs transitions. So, we simply use that as batch_size and sample
         # the first element from the dataloader.
-        inputs = next(iter(transition_tensor_dataloader))
+        inputs: th.Tensor = next(iter(transition_tensor_dataloader))
+        inputs = inputs.to(device)
+        # Ensure loaded data is FloatTensor and not DoubleTensor.
+        inputs = inputs.float()
     else:  # When using GAN.
         # Inputs should be some samples of input vectors? Not sure if this is the best
         # way to do this, there might be better options.
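A small standalone sketch of the dtype issue the added .float() call guards against: tensors built from float64 numpy data come out as DoubleTensors, which mismatch the float32 weights most torch models use; the device variable below is illustrative.

import numpy as np
import torch as th

device = th.device("cuda" if th.cuda.is_available() else "cpu")

batch = th.as_tensor(np.random.rand(4, 3, 64, 64))  # float64 numpy -> DoubleTensor
batch = batch.to(device)  # move to the same device as the model
batch = batch.float()     # cast to float32 to match the model's weights
assert batch.dtype == th.float32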
@@ -151,6 +158,8 @@ def interpret(
     rows, columns = 1, num_features
     if pyplot:
         fig = plt.figure(figsize=(columns * 2, rows * 2))  # width, height in inches
+    else:
+        fig = None
 
     # Visualize
     if vis_type == "traditional":
@@ -213,17 +222,29 @@ def interpret(
 
 
 def plot_img(
-    columns, custom_logger, feature_i, fig, img, pyplot, rows, vis_scale, wandb_logging
+    columns: int,
+    custom_logger: HierarchicalLogger,
+    feature_i: int,
+    fig: Optional[matplotlib.figure.Figure],
+    img: np.ndarray,
+    pyplot: bool,
+    rows: int,
+    vis_scale: int,
+    wandb_logging: bool,
 ):
     """Plot the passed image to pyplot and wandb as appropriate."""
     _wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
-    if pyplot:
+    if fig is not None and pyplot:
         fig.add_subplot(rows, columns, feature_i + 1)
         plt.imshow(img)
 
 
 def _wandb_log(
-    custom_logger, feature_i: int, img: np.ndarray, vis_scale: int, wandb_logging: bool
+    custom_logger: HierarchicalLogger,
+    feature_i: int,
+    img: np.ndarray,
+    vis_scale: int,
+    wandb_logging: bool,
 ):
     """Plot to wandb if wandb logging is enabled."""
     if wandb_logging:
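A minimal sketch (names are illustrative, not from the repo) of why the guard matters: when pyplot output is disabled, fig is now None, and calling add_subplot on it would fail, so the check short-circuits first.

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def plot_one(fig: Optional[Figure], img: np.ndarray, idx: int) -> None:
    # Mirrors the patched plot_img: only touch the figure if one was created.
    if fig is not None:
        fig.add_subplot(1, 4, idx + 1)
        plt.imshow(img)


plot_one(None, np.zeros((8, 8)), 0)  # safe no-op when plotting is disabled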

src/reward_preprocessing/scripts/train_gan.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 For use in reward function feature visualization.
 """
 
-import torch as th
 from sacred.observers import FileStorageObserver
+import torch as th
 
 from reward_preprocessing.common import utils
 from reward_preprocessing.scripts.config.train_gan import train_gan_ex

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 8 additions & 2 deletions
@@ -177,13 +177,17 @@ def __init__(
 
         self.patch_h = self.model_inputs_full.shape[2] / activations.shape[2]
         self.patch_w = self.model_inputs_full.shape[3] / activations.shape[3]
+
+        # From here on activations should be numpy array and not pytorch tensor anymore.
+        activations = activations.detach().cpu().numpy()
+
         if self.reducer is None:  # No dimensionality reduction.
             # Activations are only used for dim reduction and to determine the shape
             # of the features. The former is compatible between torch and numpy (both
             # support .shape), so calling .numpy() is not really necessary. However,
             # for consistency we do it here. Consequently, self.acts_reduced is always
             # a numpy array.
-            self.acts_reduced = activations.numpy()
+            self.acts_reduced = activations
             self.channel_dirs = np.eye(self.acts_reduced.shape[1])
             self.transform = lambda acts: acts.copy()
             self.inverse_transform = lambda acts: acts.copy()
@@ -219,7 +223,9 @@ def __init__(
             )
             # Transform into torch tensor instead of numpy array, because this is expected
             # later on.
-            self.channel_dirs = th.tensor(self.channel_dirs)
+            self.channel_dirs = th.tensor(self.channel_dirs).to(
+                self.model_inputs_full.device
+            )
 
     def vis_traditional(
         self,
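Aside: a standalone sketch of the two conversion pitfalls these changes address (the shapes and the "cuda"/"cpu" choice below are illustrative, not taken from the repo):

import numpy as np
import torch as th

# 1) .numpy() fails on tensors that require grad or live on the GPU;
#    detach().cpu().numpy() handles both, hence converting activations up front.
activations = th.randn(8, 16, 4, 4, requires_grad=True)
activations_np = activations.detach().cpu().numpy()

# 2) Tensors created from numpy default to the CPU; moving channel_dirs onto the
#    same device as the model inputs avoids device-mismatch errors downstream.
device = th.device("cuda" if th.cuda.is_available() else "cpu")
channel_dirs = th.tensor(np.eye(16)).to(device)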
