 # Later, we will see how the target parameters should be updated in TorchRL.
 #

-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential


 def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")

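For readers unfamiliar with the new pattern, here is a minimal, self-contained sketch (toy module, keys and shapes are illustrative, not the tutorial's networks) of the ``to_module`` idiom the hunk above adopts: parameters stored in a :class:`~tensordict.TensorDict` are written into the module for the duration of the ``with`` block and restored on exit, which replaces the older ``module(td, params=...)`` functional call. Detaching the parameter tensordict first, as done for the value network above, keeps gradients from flowing into those parameters. ::

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # Toy network and data, standing in for the tutorial's actor/value nets.
    net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["out"])
    params = TensorDict.from_module(net)   # extract the parameters as a TensorDict
    data = TensorDict({"obs": torch.randn(3)}, [])

    with params.to_module(net):            # temporarily load these parameters
        net(data)                          # plain call; no params= argument
    # on exit, the module's previous parameters are restored

    with params.detach().to_module(net):   # same call, but gradients will not reach params
        net(data)
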
@@ -317,7 +316,8 @@ def _loss_value(
     td_copy = tensordict.clone()

     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)

     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
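The manual reconstruction above relies on the parameters of a two-part :class:`~tensordict.nn.TensorDictSequential` being nested under ``"module"`` with sub-keys ``"0"`` and ``"1"``. A hedged, self-contained sketch with toy sub-modules (all names and shapes below are illustrative): ::

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torch import nn

    first = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"])
    second = TensorDictModule(nn.Linear(4, 1), in_keys=["hidden"], out_keys=["value"])
    seq = TensorDictSequential(first, second)

    # Rebuild the sequence's parameter structure from the sub-modules' parameters,
    # mirroring how ``target_params`` is assembled above (detached, as befits a target).
    rebuilt = TensorDict(
        {
            "module": {
                "0": TensorDict.from_module(first),
                "1": TensorDict.from_module(second),
            }
        },
        batch_size=[],
    ).detach()

    data = TensorDict({"obs": torch.randn(3)}, [])
    with rebuilt.to_module(seq):
        seq(data)   # runs with the swapped-in parameters
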
@@ -717,7 +716,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000

-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()

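A self-contained sketch of the composition introduced above, with a toy linear actor and a made-up action spec standing in for ``actor`` and ``actor.spec.clone()``: chaining the policy with an :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule` inside a :class:`~tensordict.nn.TensorDictSequential` yields a module that the collector can call like any other policy, with the noise module perturbing the ``"action"`` entry written by the actor. ::

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import OrnsteinUhlenbeckProcessModule

    # Toy action spec and actor; the real code uses the DDPG actor and its spec.
    action_spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
    toy_actor = TensorDictModule(
        torch.nn.Linear(4, 4, bias=False), in_keys=["observation"], out_keys=["action"]
    )
    explorative_policy = TensorDictSequential(
        toy_actor,
        OrnsteinUhlenbeckProcessModule(
            spec=action_spec,
            annealing_num_steps=1_000,
        ),
    )
    td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
    explorative_policy(td)   # deterministic action plus OU noise, written in place
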
@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
     )

     # update the exploration strategy
-    actor_model_explore.step(current_frames)
+    actor_model_explore[1].step(current_frames)

 collector.shutdown()
 del collector
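Continuing the sketch above: a :class:`~tensordict.nn.TensorDictSequential` can be indexed to reach its sub-modules, which is why ``step`` is now called through ``actor_model_explore[1]`` (the OU module) after each collected batch instead of on the policy itself. ::

    ou = explorative_policy[1]   # the OrnsteinUhlenbeckProcessModule from the sketch above
    ou.step(td.numel())          # anneal the noise by the number of frames just collected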