@@ -4459,6 +4459,69 @@ def test_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
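+    # Checks that the SAC loss ignores terminal next-states: their
+    # observations are deliberately poisoned with NaNs below, and the
+    # resulting loss is expected to remain finite.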
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key, version
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+            action_key=action_key,
+            out_keys=["state_action_value"],
+        )
+        if version == 1:
+            value = self._create_mock_value(observation_key=observation_key)
+        else:
+            value = None
+
+        loss = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        torch.manual_seed(self.seed)
+
+        SoftUpdate(loss, eps=0.5)
+
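+        # Resample `done` until the batch mixes terminal and non-terminal
+        # transitions, then overwrite the next observation of every terminal
+        # step with NaN; a loss that masks terminal states stays finite.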
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done.bernoulli_(0.1)
+        obs_nan = td.get(("next", observation_key))
+        obs_nan[done.squeeze(-1)] = float("nan")
+
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": td.get(("next", terminated_key)),
+            f"next_{observation_key}": obs_nan,
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     def test_state_dict(self, version):
         if version == 1:
             pytest.skip("Test not implemented for version 1.")
@@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
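+    # Discrete-action counterpart of test_sac_terminating: DiscreteSACLoss
+    # should likewise produce a finite loss when terminal next-observations
+    # are NaN.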
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_discrete_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+        )
+
+        loss = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec[action_key].space.n,
+            action_space="one-hot",
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        SoftUpdate(loss, eps=0.5)
+
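+        # Same NaN-poisoning scheme as in test_sac_terminating above.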
+        torch.manual_seed(0)
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done = done.bernoulli_(0.1)
+        obs_none = td.get(("next", observation_key))
+        obs_none[done.squeeze(-1)] = float("nan")
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": td.get(("next", terminated_key)),
+            f"next_{observation_key}": obs_none,
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_discrete_sac_reduction(self, reduction):
         torch.manual_seed(self.seed)