Commit c8676f4

[BugFix] Account for terminating data in SAC losses
ghstack-source-id: dc1870292786c262b4ab6a221b3afb551e0efb9b
Pull Request resolved: #2606

1 parent: d90b9e3
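
Background, not part of the diff: a terminal transition's "next" entries may hold arbitrary, even NaN, data, and any NaN that reaches the actor or critic propagates through the network and poisons the loss. A minimal, TorchRL-independent sketch of that failure mode (the linear layer is just a stand-in for the actor/critic):

    import torch

    net = torch.nn.Linear(4, 1)
    obs = torch.randn(3, 4)
    obs[0] = float("nan")  # a terminal next-observation with undefined content
    assert not net(obs).isfinite().all()  # the NaN propagates to the output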

2 files changed: +162 -8 lines changed

test/test_cost.py  (+119)

@@ -4459,6 +4459,69 @@ def test_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key, version
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+            action_key=action_key,
+            out_keys=["state_action_value"],
+        )
+        if version == 1:
+            value = self._create_mock_value(observation_key=observation_key)
+        else:
+            value = None
+
+        loss = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        torch.manual_seed(self.seed)
+
+        SoftUpdate(loss, eps=0.5)
+
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done.bernoulli_(0.1)
+        obs_nan = td.get(("next", terminated_key))
+        obs_nan[done.squeeze(-1)] = float("nan")
+
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": obs_nan,
+            f"next_{observation_key}": td.get(("next", observation_key)),
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     def test_state_dict(self, version):
         if version == 1:
             pytest.skip("Test not implemented for version 1.")
@@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_discrete_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+        )
+
+        loss = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec[action_key].space.n,
+            action_space="one-hot",
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        SoftUpdate(loss, eps=0.5)
+
+        torch.manual_seed(0)
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done = done.bernoulli_(0.1)
+        obs_none = td.get(("next", observation_key))
+        obs_none[done.squeeze(-1)] = float("nan")
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": td.get(("next", terminated_key)),
+            f"next_{observation_key}": obs_none,
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_discrete_sac_reduction(self, reduction):
         torch.manual_seed(self.seed)
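
Both new tests follow the same recipe: flip some (but not all) done flags, overwrite the corresponding "next" entries with NaN, and assert that the loss stays finite, which can only hold if the loss never consumes next-step data for terminated samples. A stripped-down, TorchRL-independent version of the property being checked (the sum is a stand-in for the target-value computation):

    import torch

    batch = 8
    done = torch.zeros(batch, 1, dtype=torch.bool)
    done[:3] = True                              # some, but not all, samples terminate
    next_obs = torch.randn(batch, 4)
    next_obs[done.squeeze(-1)] = float("nan")    # terminal next-data is garbage

    mask = ~done.squeeze(-1)
    value_sel = next_obs[mask].sum(-1, keepdim=True)   # computed on surviving rows only
    value = value_sel.new_zeros(batch, 1)
    value[mask] = value_sel                            # scattered back to the full batch
    assert value.isfinite().all()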

torchrl/objectives/sac.py  (+43 -8)

@@ -16,7 +16,7 @@
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.utils import NestedKey
+from tensordict.utils import expand_right, NestedKey
 from torch import Tensor
 from torchrl.data.tensor_specs import Composite, TensorSpec
 from torchrl.data.utils import _find_action_space
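
The newly imported expand_right is used below to broadcast the boolean done-mask over the trailing action / log-prob dimensions before scattering. A small sketch of its behaviour (shapes are illustrative):

    import torch
    from tensordict.utils import expand_right

    mask = torch.tensor([True, False, True])   # shape [3]
    expanded = expand_right(mask, (3, 4))      # shape [3, 4]: each entry repeated along the new dim
    assert expanded.shape == (3, 4)
    assert not expanded[1].any()               # the False row stays False across the trailing dim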
@@ -711,13 +711,37 @@ def _compute_target_v2(self, tensordict) -> Tensor:
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
-            next_tensordict = tensordict.get("next").clone(False)
-            next_dist = self.actor_network.get_dist(next_tensordict)
+            next_tensordict = tensordict.get("next").copy()
+            # Check done state and avoid passing these to the actor
+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+            next_dist = self.actor_network.get_dist(next_tensordict_select)
             next_action = next_dist.rsample()
-            next_tensordict.set(self.tensor_keys.action, next_action)
             next_sample_log_prob = compute_log_prob(
                 next_dist, next_action, self.tensor_keys.log_prob
             )
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_action.ndim:
+                    mask = expand_right(
+                        mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                    )
+                next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                    mask, next_action
+                )
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_sample_log_prob.ndim:
+                    mask = expand_right(
+                        mask,
+                        (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                    )
+                next_sample_log_prob = next_sample_log_prob.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_sample_log_prob)
+            next_tensordict.set(self.tensor_keys.action, next_action)
 
             # get q-values
             next_tensordict_expand = self._vmap_qnetworkN0(
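
The write-back above relies on masked_scatter_, which copies source elements into the positions where the mask is True, in row-major order. A self-contained sketch of the pattern used for next_action (shapes are illustrative, not taken from the loss; mask.unsqueeze(-1).expand(...) plays the role of expand_right here):

    import torch

    done = torch.tensor([[False], [True], [False]])       # [batch, 1]
    mask = ~done.squeeze(-1)                               # [batch]
    action_sel = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # actions for the 2 surviving rows

    mask_exp = mask.unsqueeze(-1).expand(3, 2)             # broadcast mask over the action dim
    action_full = action_sel.new_zeros(mask_exp.shape).masked_scatter_(mask_exp, action_sel)
    # action_full == [[1., 2.], [0., 0.], [3., 4.]]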
@@ -1194,15 +1218,21 @@ def _compute_target(self, tensordict) -> Tensor:
         with torch.no_grad():
             next_tensordict = tensordict.get("next").clone(False)
 
+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+
             # get probs and log probs for actions computed from "next"
             with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict)
-                next_prob = next_dist.probs
-                next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+                next_dist = self.actor_network.get_dist(next_tensordict_select)
+                next_log_prob = next_dist.logits
+                next_prob = next_log_prob.exp()
 
             # get q-values for all actions
             next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict, self.target_qvalue_network_params
+                next_tensordict_select, self.target_qvalue_network_params
             )
             next_action_value = next_tensordict_expand.get(
                 self.tensor_keys.action_value
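
Besides the done-masking, the hunk above replaces torch.log(torch.where(probs == 0, 1e-8, probs)) with the distribution's own logits: for torch's categorical families the logits property is stored normalized, so it already plays the role of log-probabilities without an epsilon floor. A quick sketch of that equivalence:

    import torch
    from torch.distributions import OneHotCategorical

    dist = OneHotCategorical(logits=torch.randn(4, 3))
    log_p = dist.logits              # normalized log-probabilities
    p = log_p.exp()
    assert torch.allclose(p, dist.probs)
    assert torch.allclose(p.sum(-1), torch.ones(4))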
@@ -1212,6 +1242,11 @@ def _compute_target(self, tensordict) -> Tensor:
             next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
             # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
             next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                next_state_value = next_state_value.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_state_value)
 
             tensordict.set(
                 ("next", self.value_estimator.tensor_keys.value), next_state_value
