Skip to content

Commit 8cf91c3

Browse files
authored
Merge pull request #26 from HumanCompatibleAI/harmonize_naming
Harmonize reward net naming conventions
2 parents c629cd8 + ffd535a commit 8cf91c3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -221,17 +221,17 @@ class RewardGeneratorCombo(nn.Module):
221221
Assumes that the RewardNet normalizes observations to [0,1].
222222
"""
223223

224-
def __init__(self, reward_net: RewardNet, generator: nn.Module):
224+
def __init__(self, rew_net: RewardNet, generator: nn.Module):
225225
super().__init__()
226-
self.reward_net = reward_net
226+
self.rew_net = rew_net
227227
self.generator = generator
228228

229229
def forward(self, latent_tens: th.Tensor):
230230
latent_vec = th.mean(latent_tens, dim=[2, 3])
231231
transition_tensor = self.generator(latent_vec)
232232
obs, action_vec, next_obs = tensor_to_transition(transition_tensor)
233233
done = th.zeros(action_vec.shape)
234-
return self.reward_net.forward(obs, action_vec, next_obs, done)
234+
return self.rew_net.forward(obs, action_vec, next_obs, done)
235235

236236

237237
def log_img_wandb(

src/reward_preprocessing/interpret.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ def interpret(
179179
# Combine rew net with GAN.
180180
gan = th.load(gan_path, map_location=th.device(device))
181181
model_to_analyse = RewardGeneratorCombo(
182-
reward_net=rew_net, generator=gan.generator
182+
rew_net=rew_net, generator=gan.generator
183183
)
184184

185185
model_to_analyse.eval() # Eval for visualization.

0 commit comments

Comments
 (0)