Skip to content

Commit 40fcdb6

Browse files
committed
[Feature] VecNormV2
ghstack-source-id: 639d07f Pull Request resolved: #2867
1 parent 5d72561 commit 40fcdb6

File tree

11 files changed

+1209
-15
lines changed

11 files changed

+1209
-15
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,7 @@ to be able to create this other composition:
10831083
VIPTransform
10841084
VecGymEnvTransform
10851085
VecNorm
1086+
VecNormV2
10861087
gSDENoise
10871088

10881089
Environments with masked actions

test/mocking_classes.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
10971097
self.done_spec = Categorical(
10981098
2,
10991099
dtype=torch.bool,
1100-
shape=(
1101-
*self.batch_size,
1102-
1,
1103-
),
1100+
shape=(*self.batch_size, 1),
11041101
device=self.device,
11051102
)
11061103
self.action_spec = Binary(n=1, shape=[*self.batch_size, 1], device=self.device)
@@ -1146,7 +1143,9 @@ def _step(
11461143
"observation": self.count.clone(),
11471144
"done": self.count > self.max_steps,
11481145
"terminated": self.count > self.max_steps,
1149-
"reward": torch.zeros_like(self.count, dtype=torch.float),
1146+
"reward": torch.zeros_like(
1147+
self.count, dtype=self.full_reward_spec[self.reward_keys[0]].dtype
1148+
),
11501149
},
11511150
batch_size=self.batch_size,
11521151
device=self.device,
@@ -1300,7 +1299,11 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
13001299
source[group_name][agent_name] = TensorDict(
13011300
source={
13021301
"observation": torch.rand(
1303-
(*self.batch_size, 3, 4), device=self.device
1302+
(*self.batch_size, 3, 4),
1303+
device=self.device,
1304+
dtype=self.full_observation_spec[
1305+
group_name, agent_name, "observation"
1306+
].dtype,
13041307
),
13051308
"done": self.count > self.max_steps,
13061309
"terminated": self.count > self.max_steps,
@@ -1324,11 +1327,20 @@ def _step(
13241327
source[group_name][agent_name] = TensorDict(
13251328
source={
13261329
"observation": torch.rand(
1327-
(*self.batch_size, 3, 4), device=self.device
1330+
(*self.batch_size, 3, 4),
1331+
device=self.device,
1332+
dtype=self.full_observation_spec[
1333+
group_name, agent_name, "observation"
1334+
].dtype,
13281335
),
13291336
"done": self.count > self.max_steps,
13301337
"terminated": self.count > self.max_steps,
1331-
"reward": torch.zeros_like(self.count, dtype=torch.float),
1338+
"reward": torch.zeros_like(
1339+
self.count,
1340+
dtype=self.full_reward_spec[
1341+
group_name, agent_name, "reward"
1342+
].dtype,
1343+
),
13321344
},
13331345
batch_size=self.batch_size,
13341346
device=self.device,

test/test_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4257,7 +4257,7 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42574257
if stateful:
42584258
all_actions = env.all_actions()
42594259
else:
4260-
# Reset the the initial state first, just to make sure
4260+
# Reset theinitial state first, just to make sure
42614261
# `all_actions` knows how to get the board state from the input.
42624262
env.reset()
42634263
all_actions = env.all_actions(td.clone())

test/test_specs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,10 @@ def test_repr(self, shape, is_complete, device, dtype):
525525

526526
def test_device_cast_with_dtype_fails(self, shape, is_complete, device, dtype):
527527
ts = self._composite_spec(shape, is_complete, device, dtype)
528-
with pytest.raises(ValueError, match="Only device casting is allowed"):
529-
ts.to(torch.float16)
528+
ts = ts.to(torch.float16)
529+
for spec in ts.values(True, True):
530+
if spec is not None:
531+
assert spec.dtype == torch.float16
530532

531533
@pytest.mark.parametrize("dest", get_available_devices())
532534
def test_device_cast(self, shape, is_complete, device, dtype, dest):

0 commit comments

Comments
 (0)