File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -221,17 +221,17 @@ class RewardGeneratorCombo(nn.Module):
221
221
Assumes that the RewardNet normalizes observations to [0,1].
222
222
"""
223
223
224
- def __init__ (self , reward_net : RewardNet , generator : nn .Module ):
224
+ def __init__ (self , rew_net : RewardNet , generator : nn .Module ):
225
225
super ().__init__ ()
226
- self .reward_net = reward_net
226
+ self .rew_net = rew_net
227
227
self .generator = generator
228
228
229
229
def forward (self , latent_tens : th .Tensor ):
230
230
latent_vec = th .mean (latent_tens , dim = [2 , 3 ])
231
231
transition_tensor = self .generator (latent_vec )
232
232
obs , action_vec , next_obs = tensor_to_transition (transition_tensor )
233
233
done = th .zeros (action_vec .shape )
234
- return self .reward_net .forward (obs , action_vec , next_obs , done )
234
+ return self .rew_net .forward (obs , action_vec , next_obs , done )
235
235
236
236
237
237
def log_img_wandb (
Original file line number Diff line number Diff line change @@ -179,7 +179,7 @@ def interpret(
179
179
# Combine rew net with GAN.
180
180
gan = th .load (gan_path , map_location = th .device (device ))
181
181
model_to_analyse = RewardGeneratorCombo (
182
- reward_net = rew_net , generator = gan .generator
182
+ rew_net = rew_net , generator = gan .generator
183
183
)
184
184
185
185
model_to_analyse .eval () # Eval for visualization.
You can’t perform that action at this time.
0 commit comments