
Commit d706405

Author: Vincent Moens
Commit message: amend

1 parent a688b90 · commit d706405

File tree

1 file changed: +18 −16 lines changed


advanced_source/coding_ddpg.py

Lines changed: 18 additions & 16 deletions
@@ -182,7 +182,7 @@
 # Later, we will see how the target parameters should be updated in TorchRL.
 #

-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential


 def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")


@@ -317,7 +316,8 @@ def _loss_value(
     td_copy = tensordict.clone()

     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)

     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
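
The three hunks above all swap the functional `params=` keyword for tensordict's `to_module` context manager, which temporarily loads a parameter TensorDict into a module for the duration of the block. Below is a minimal sketch of that pattern, assuming a hypothetical toy `Linear` network and made-up keys that are not part of the tutorial:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Hypothetical toy module, for illustration only.
module = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_action_value"]
)

# Snapshot the parameters as a TensorDict (this mirrors what the loss module stores).
params = TensorDict.from_module(module)

data = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])

# Old style (removed above): module(data, params=params)
# New style: temporarily write the parameters into the module for the call.
with params.to_module(module):
    data = module(data)

# Detached parameters can be substituted the same way when no gradient should
# flow through them, as in the value-network call of _loss_actor above.
with params.detach().to_module(module):
    data = module(data)

The same substitution works with a set of target parameters, which is how `_loss_value` now computes the target value estimate through the composed actor-critic.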
@@ -717,7 +716,7 @@ def get_env_stats():
     ActorCriticWrapper,
     DdpgMlpActor,
     DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     ProbabilisticActor,
     TanhDelta,
     ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
 # Exploration
 # ~~~~~~~~~~~
 #
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
 # exploration module, as suggested in the original paper.
 # Let's define the number of frames before OU noise reaches its minimum value
 annealing_frames = 1_000_000

-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
     actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
 if device == torch.device("cpu"):
     actor_model_explore.share_memory()

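On the exploration side, the policy is no longer wrapped: an `OrnsteinUhlenbeckProcessModule` is appended to it inside a `TensorDictSequential`. A minimal sketch of the composition, assuming a hypothetical toy actor and a [-1, 1] action spec that are not taken from the tutorial:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import BoundedTensorSpec
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

# Hypothetical toy policy: a linear layer reading "observation" and writing "action".
action_spec = BoundedTensorSpec(-1, 1, shape=(2,))
actor = Actor(torch.nn.Linear(3, 2), in_keys=["observation"], spec=action_spec)

# The OU process is now a module of its own, appended after the policy.
actor_explore = TensorDictSequential(
    actor,
    OrnsteinUhlenbeckProcessModule(
        spec=actor.spec.clone(),
        annealing_num_steps=1_000,  # arbitrary schedule length for this sketch
    ),
)

data = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
data = actor_explore(data)  # actor output plus OU exploration noise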

@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
 )

 # update the exploration strategy
-actor_model_explore.step(current_frames)
+actor_model_explore[1].step(current_frames)

 collector.shutdown()
 del collector
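
Because the exploration policy is now a `TensorDictSequential`, the noise schedule lives in its second element: `actor_model_explore[1]` retrieves the `OrnsteinUhlenbeckProcessModule`, and its `step()` advances the annealing by the number of frames just collected (the former wrapper exposed `step()` directly on the policy).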
