Skip to content

Commit 481bc77

Browse files
committed
Harmonize reward net naming conventions
1 parent c629cd8 commit 481bc77

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,17 @@ class RewardGeneratorCombo(nn.Module):
221221
Assumes that the RewardNet normalizes observations to [0,1].
222222
"""
223223

224-
def __init__(self, reward_net: RewardNet, generator: nn.Module):
224+
def __init__(self, rew_net: RewardNet, generator: nn.Module):
225225
super().__init__()
226-
self.reward_net = reward_net
226+
self.rew_net = rew_net
227227
self.generator = generator
228228

229229
def forward(self, latent_tens: th.Tensor):
230230
latent_vec = th.mean(latent_tens, dim=[2, 3])
231231
transition_tensor = self.generator(latent_vec)
232232
obs, action_vec, next_obs = tensor_to_transition(transition_tensor)
233233
done = th.zeros(action_vec.shape)
234-
return self.reward_net.forward(obs, action_vec, next_obs, done)
234+
return self.rew_net.forward(obs, action_vec, next_obs, done)
235235

236236

237237
def log_img_wandb(

0 commit comments

Comments
 (0)