import rl_games.common.wrappers as wrappers
from rl_games.common.ivecenv import IVecEnv

from gymnasium import spaces
def remove_batch_dim(space: spaces.Space) -> spaces.Space:
    """Recursively remove the first (batch) dimension from a Gymnasium space."""
    if isinstance(space, spaces.Box):
        # assume shape = (B, *shape); drop the batch axis by indexing row 0
        low = space.low[0]
        high = space.high[0]
        return spaces.Box(low=low, high=high, dtype=space.dtype)
    elif isinstance(space, spaces.MultiDiscrete):
        # assume nvec = (B, n); take the first row
        nvec = space.nvec[0]
        return spaces.MultiDiscrete(nvec)
    elif isinstance(space, spaces.MultiBinary):
        # n can be an int or array-like
        n = space.n[0] if hasattr(space.n, "__len__") else space.n
        return spaces.MultiBinary(n)
    elif isinstance(space, spaces.Discrete):
        # Discrete spaces have no extra dims
        return space
    elif isinstance(space, spaces.Tuple):
        return spaces.Tuple(tuple(remove_batch_dim(s) for s in space.spaces))
    elif isinstance(space, spaces.Dict):
        return spaces.Dict({k: remove_batch_dim(s) for k, s in space.spaces.items()})
    else:
        raise ValueError(f"Unsupported space type: {type(space)}")
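
# A quick illustration of remove_batch_dim (the shapes below are hypothetical,
# chosen for this example rather than taken from any particular task): a Box
# batched over 16 envs collapses to its per-env shape, which is the unbatched
# form rl_games expects from get_env_info().
#
#   batched = spaces.Box(low=-1.0, high=1.0, shape=(16, 42))
#   remove_batch_dim(batched)  # -> Box(-1.0, 1.0, (42,), float32)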


class ManiskillEnv(IVecEnv):
    def __init__(self, config_name, num_actors, **kwargs):
        import gymnasium
        import mani_skill.envs
        from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
        from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper

        self.batch_size = num_actors
        env_name = kwargs.pop('env_name')
        self.seed = kwargs.pop('seed', 0)  # not sure how to set this in mani_skill
        self.env = gymnasium.make(
            env_name,
            num_envs=num_actors,
            **kwargs
        )
        # self.env = FlattenRGBDObservationWrapper(self.env, rgb=True, depth=False, state=False, sep_depth=False)
        # wrap the vector env in ManiSkillVectorEnv so done sub-envs
        # are reset automatically under the hood
        self.env = ManiSkillVectorEnv(self.env)

        print(f"ManiSkill env: {env_name} with {num_actors} actors")
        print(f"Original observation space: {self.env.observation_space}")
        self.action_space = wrappers.OldGymWrapper.convert_space(remove_batch_dim(self.env.action_space))
        self.observation_space = wrappers.OldGymWrapper.convert_space(remove_batch_dim(self.env.observation_space))
        print(f"Converted action space: {self.action_space}")
        print(f"Converted observation space: {self.observation_space}")

    def step(self, action):
        next_obs, reward, done, truncated, info = self.env.step(action)
        # rl_games reads info['time_outs'] (with value_bootstrap enabled) to
        # bootstrap the value function on truncations instead of treating
        # them as true terminations
        is_done = done | truncated
        info['time_outs'] = truncated
        return next_obs, reward, is_done, info

    def reset(self):
        obs, _ = self.env.reset()
        return obs

    def get_number_of_agents(self):
        return 1

    def get_env_info(self):
        info = {}
        info['action_space'] = self.action_space
        info['observation_space'] = self.observation_space
        return info


def create_maniskill_env(config_name, num_actors, **kwargs):
    return ManiskillEnv(config_name, num_actors, **kwargs)
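
# A minimal sketch of wiring this creator into rl_games' registries. The
# 'MANISKILL' vecenv key and the 'maniskill' configuration name below are
# illustrative assumptions, not names defined by this file:
#
#   from rl_games.common import vecenv, env_configurations
#
#   vecenv.register(
#       'MANISKILL',
#       lambda config_name, num_actors, **kwargs: create_maniskill_env(config_name, num_actors, **kwargs))
#   env_configurations.register('maniskill', {
#       'vecenv_type': 'MANISKILL',
#       'env_creator': lambda **kwargs: create_maniskill_env('maniskill', kwargs.pop('num_actors', 1), **kwargs),
#   })
#
# A training config would then select this env via env_name: maniskill.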