@@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
689
689
)
690
690
check_env_specs (env )
691
691
692
+ @pytest .mark .parametrize ("cat_dim" , [- 1 , - 2 , - 3 ])
693
+ @pytest .mark .parametrize ("cat_N" , [3 , 10 ])
694
+ @pytest .mark .parametrize ("device" , get_default_devices ())
695
+ def test_with_permute_no_env (self , cat_dim , cat_N , device ):
696
+ torch .manual_seed (cat_dim * cat_N )
697
+ pixels = torch .randn (8 , 5 , 3 , 10 , 4 , device = device )
698
+
699
+ a = TensorDict (
700
+ {
701
+ "pixels" : pixels ,
702
+ },
703
+ [
704
+ pixels .shape [0 ],
705
+ ],
706
+ device = device ,
707
+ )
708
+
709
+ t0 = Compose (
710
+ CatFrames (N = cat_N , dim = cat_dim ),
711
+ )
712
+
713
+ def get_rand_perm (ndim ):
714
+ cat_dim_perm = cat_dim
715
+ # Ensure that the permutation moves the cat_dim
716
+ while cat_dim_perm == cat_dim :
717
+ perm_pos = torch .randperm (ndim )
718
+ perm = perm_pos - ndim
719
+ cat_dim_perm = (perm == cat_dim ).nonzero ().item () - ndim
720
+ perm_inv = perm_pos .argsort () - ndim
721
+ return perm .tolist (), perm_inv .tolist (), cat_dim_perm
722
+
723
+ perm , perm_inv , cat_dim_perm = get_rand_perm (pixels .dim () - 1 )
724
+
725
+ t1 = Compose (
726
+ PermuteTransform (perm , in_keys = ["pixels" ]),
727
+ CatFrames (N = cat_N , dim = cat_dim_perm ),
728
+ PermuteTransform (perm_inv , in_keys = ["pixels" ]),
729
+ )
730
+
731
+ b = t0 ._call (a .clone ())
732
+ c = t1 ._call (a .clone ())
733
+ assert (b == c ).all ()
734
+
735
+ @pytest .mark .skipif (not _has_gym , reason = "Test executed on gym" )
736
+ @pytest .mark .parametrize ("cat_dim" , [- 1 , - 2 ])
737
+ def test_with_permute_env (self , cat_dim ):
738
+ env0 = TransformedEnv (
739
+ GymEnv ("Pendulum-v1" ),
740
+ Compose (
741
+ UnsqueezeTransform (- 1 , in_keys = ["observation" ]),
742
+ CatFrames (N = 4 , dim = cat_dim , in_keys = ["observation" ]),
743
+ ),
744
+ )
745
+
746
+ env1 = TransformedEnv (
747
+ GymEnv ("Pendulum-v1" ),
748
+ Compose (
749
+ UnsqueezeTransform (- 1 , in_keys = ["observation" ]),
750
+ PermuteTransform ((- 1 , - 2 ), in_keys = ["observation" ]),
751
+ CatFrames (N = 4 , dim = - 3 - cat_dim , in_keys = ["observation" ]),
752
+ PermuteTransform ((- 1 , - 2 ), in_keys = ["observation" ]),
753
+ ),
754
+ )
755
+
756
+ torch .manual_seed (0 )
757
+ env0 .set_seed (0 )
758
+ td0 = env0 .reset ()
759
+
760
+ torch .manual_seed (0 )
761
+ env1 .set_seed (0 )
762
+ td1 = env1 .reset ()
763
+
764
+ assert (td0 == td1 ).all ()
765
+
766
+ td0 = env0 .step (td0 .update (env0 .full_action_spec .rand ()))
767
+ td1 = env0 .step (td0 .update (env1 .full_action_spec .rand ()))
768
+
769
+ assert (td0 == td1 ).all ()
770
+
692
771
def test_serial_trans_env_check (self ):
693
772
env = SerialEnv (
694
773
2 ,
0 commit comments