Skip to content

Commit d4e4019

Browse files
committed
[Test] Add tests for CatFrames with PermuteTransform
ghstack-source-id: e554d1c Pull Request resolved: #2715
1 parent 80690d2 commit d4e4019

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

test/test_transforms.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
689689
)
690690
check_env_specs(env)
691691

692+
@pytest.mark.parametrize("cat_dim", [-1, -2, -3])
693+
@pytest.mark.parametrize("cat_N", [3, 10])
694+
@pytest.mark.parametrize("device", get_default_devices())
695+
def test_with_permute_no_env(self, cat_dim, cat_N, device):
696+
torch.manual_seed(cat_dim * cat_N)
697+
pixels = torch.randn(8, 5, 3, 10, 4, device=device)
698+
699+
a = TensorDict(
700+
{
701+
"pixels": pixels,
702+
},
703+
[
704+
pixels.shape[0],
705+
],
706+
device=device,
707+
)
708+
709+
t0 = Compose(
710+
CatFrames(N=cat_N, dim=cat_dim),
711+
)
712+
713+
def get_rand_perm(ndim):
714+
cat_dim_perm = cat_dim
715+
# Ensure that the permutation moves the cat_dim
716+
while cat_dim_perm == cat_dim:
717+
perm_pos = torch.randperm(ndim)
718+
perm = perm_pos - ndim
719+
cat_dim_perm = (perm == cat_dim).nonzero().item() - ndim
720+
perm_inv = perm_pos.argsort() - ndim
721+
return perm.tolist(), perm_inv.tolist(), cat_dim_perm
722+
723+
perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1)
724+
725+
t1 = Compose(
726+
PermuteTransform(perm, in_keys=["pixels"]),
727+
CatFrames(N=cat_N, dim=cat_dim_perm),
728+
PermuteTransform(perm_inv, in_keys=["pixels"]),
729+
)
730+
731+
b = t0._call(a.clone())
732+
c = t1._call(a.clone())
733+
assert (b == c).all()
734+
735+
@pytest.mark.skipif(not _has_gym, reason="Test executed on gym")
736+
@pytest.mark.parametrize("cat_dim", [-1, -2])
737+
def test_with_permute_env(self, cat_dim):
738+
env0 = TransformedEnv(
739+
GymEnv("Pendulum-v1"),
740+
Compose(
741+
UnsqueezeTransform(-1, in_keys=["observation"]),
742+
CatFrames(N=4, dim=cat_dim, in_keys=["observation"]),
743+
),
744+
)
745+
746+
env1 = TransformedEnv(
747+
GymEnv("Pendulum-v1"),
748+
Compose(
749+
UnsqueezeTransform(-1, in_keys=["observation"]),
750+
PermuteTransform((-1, -2), in_keys=["observation"]),
751+
CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]),
752+
PermuteTransform((-1, -2), in_keys=["observation"]),
753+
),
754+
)
755+
756+
torch.manual_seed(0)
757+
env0.set_seed(0)
758+
td0 = env0.reset()
759+
760+
torch.manual_seed(0)
761+
env1.set_seed(0)
762+
td1 = env1.reset()
763+
764+
assert (td0 == td1).all()
765+
766+
td0 = env0.step(td0.update(env0.full_action_spec.rand()))
767+
td1 = env0.step(td0.update(env1.full_action_spec.rand()))
768+
769+
assert (td0 == td1).all()
770+
692771
def test_serial_trans_env_check(self):
693772
env = SerialEnv(
694773
2,

0 commit comments

Comments
 (0)