Skip to content

Commit 42ed42c

Browse files
committed
[BE] Make better logits in cost tests
ghstack-source-id: be9ea92b3f3d2592e426eaeaff7b81e50472cf16 Pull Request resolved: #2775
1 parent f6084b6 commit 42ed42c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

test/test_cost.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -4710,7 +4710,7 @@ def _create_mock_actor(
47104710
):
47114711
# Actor
47124712
action_spec = OneHot(action_dim)
4713-
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
4713+
net = nn.Linear(obs_dim, action_dim)
47144714
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
47154715
actor = ProbabilisticActor(
47164716
spec=action_spec,
@@ -11388,7 +11388,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
1138811388
action_spec = Bounded(
1138911389
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
1139011390
)
11391-
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
11391+
net = nn.Linear(obs_dim, action_dim)
1139211392
module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"])
1139311393
actor = ProbabilisticActor(
1139411394
module=module,
@@ -12632,7 +12632,7 @@ def _create_mock_actor(
1263212632
):
1263312633
# Actor
1263412634
action_spec = OneHot(action_dim)
12635-
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
12635+
net = nn.Linear(obs_dim, action_dim)
1263612636
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
1263712637
actor = ProbabilisticActor(
1263812638
spec=action_spec,
@@ -12729,8 +12729,7 @@ def _create_mock_common_layer_setup(
1272912729
common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
1273012730
actor = ProbSeq(
1273112731
common,
12732-
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
12733-
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["logits"]),
12732+
Mod(actor_net, in_keys=["hidden"], out_keys=["logits"]),
1273412733
ProbMod(
1273512734
in_keys=["logits"],
1273612735
out_keys=["action"],

0 commit comments

Comments
 (0)