
Commit 08ef2ba

Author: jayyoung0802
feature(yzj): add ptz simple env
1 parent 6f80173 · commit 08ef2ba

5 files changed: +22 additions, -20 deletions


lzero/mcts/buffer/game_buffer_muzero.py

Lines changed: 5 additions & 5 deletions
@@ -201,12 +201,12 @@ def _prepare_reward_value_context(
             td_steps_list, action_mask_segment, to_play_segment
         """
         zero_obs = game_segment_list[0].zero_obs()
-        zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
-                              'global_state': np.zeros((30,), dtype=np.float32),
-                              'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
-                              'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
+        # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
+        #                       'global_state': np.zeros((30,), dtype=np.float32),
+        #                       'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
+        #                       'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
         zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32),
-                              'global_state': np.zeros((8,), dtype=np.float32),
+                              'global_state': np.zeros((1, 14), dtype=np.float32),
                               'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
                               'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}])
         value_obs_list = []
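For reference, the new padding shapes track the single-agent ptz simple env configured below (agent_obs_shape=6, global_obs_shape=14). A minimal sketch, assuming one agent per environment, that builds the padded observation and checks its shapes:

import numpy as np

# Sketch only: zero (padding) observation for one agent; shapes mirror
# agent_obs_shape=6 and global_obs_shape=14 in ptz_simple_mz_config.py.
zero_obs = np.array([{
    'agent_state': np.zeros((1, 6), dtype=np.float32),
    'global_state': np.zeros((1, 14), dtype=np.float32),
    'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
    'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),
}])
for key, arr in zero_obs[0].items():
    print(key, arr.shape)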

lzero/model/muzero_model_mlp.py

Lines changed: 2 additions & 2 deletions
@@ -186,7 +186,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
             value,
             [0. for _ in range(batch_size)],
             policy_logits,
-            latent_state[1],
+            latent_state,
         )

     def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput:
@@ -214,7 +214,7 @@ def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor)
         """
         next_latent_state, reward = self._dynamics(latent_state, action)
         policy_logits, value = self._prediction(next_latent_state)
-        return MZNetworkOutput(value, reward, policy_logits, next_latent_state[1])
+        return MZNetworkOutput(value, reward, policy_logits, next_latent_state)

     def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]:
         """

zoo/petting_zoo/config/ptz_simple_mz_config.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@
         n_landmark=n_landmark,
         max_cycles=25,
         agent_obs_only=False,
-        agent_specific_global_state=False,
+        agent_specific_global_state=True,
         continuous_actions=False,
         stop_value=0,
         collector_env_num=collector_env_num,
@@ -52,7 +52,7 @@
             agent_num=n_agent,
             self_supervised_learning_loss=False,  # default is False
             agent_obs_shape=6,
-            global_obs_shape=8,
+            global_obs_shape=14,
             discrete_action_encoding_type='one_hot',
             global_cooperation=True,  # TODO: doesn't work now
             hidden_size_list=[256, 256],
@@ -97,7 +97,7 @@
         import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
         type='petting_zoo',
     ),
-    env_manager=dict(type='base'),
+    env_manager=dict(type='subprocess'),
     policy=dict(
         type='muzero',
         import_names=['lzero.policy.muzero'],
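Taken together, the env now exposes an agent-specific global state and the model is sized to match; 14 is consistent with concatenating the 6-dim agent observation onto the previous 8-dim global state, though that composition is an inference from the numbers, not stated in the diff. A condensed sketch of just the affected keys (surrounding config omitted):

# Sketch only: the three settings changed in this file.
env = dict(
    agent_specific_global_state=True,   # each agent gets its own global-state view
)
model = dict(
    agent_obs_shape=6,
    global_obs_shape=14,                # previously 8
)
env_manager = dict(type='subprocess')   # parallel env workers instead of 'base'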

zoo/petting_zoo/entry/train_muzero.py

Lines changed: 2 additions & 1 deletion
@@ -79,7 +79,8 @@ def train_muzero(
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

-    model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
+    # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
+    model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg))
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

     # load pretrained model
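With the custom prediction and dynamics networks commented out, only the representation network is replaced and the rest of the model keeps its defaults. A minimal sketch of the intended wiring, reusing the names already imported in this entry script and assuming cfg is the compiled experiment config from ptz_simple_mz_config.py (not constructed in this snippet):

# Sketch only: cfg.policy.model supplies agent_obs_shape=6, global_obs_shape=14, etc.,
# while PettingZooEncoder replaces the default representation network.
model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg))
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])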

zoo/petting_zoo/model/model.py

Lines changed: 10 additions & 9 deletions
@@ -23,7 +23,7 @@ def __init__(self, cfg):
                                                        norm_type='BN')

         self.global_encoder = RepresentationNetworkMLP(observation_shape=global_obs_shape,
-                                                       hidden_channels=128,
+                                                       hidden_channels=256,
                                                        norm_type='BN')

         self.encoder = RepresentationNetworkMLP(observation_shape=128+128*self.agent_num,
@@ -32,15 +32,16 @@ def __init__(self, cfg):

     def forward(self, x):
         # agent
-        batch_size, agent_num = x['agent_state'].shape[0], x['agent_state'].shape[1]
-        agent_state = x['agent_state'].reshape(batch_size*agent_num, -1)
-        agent_state = self.agent_encoder(agent_state)
-        agent_state_B = agent_state.reshape(batch_size, -1)
-        agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1)
+        batch_size, agent_num = x['global_state'].shape[0], x['global_state'].shape[1]
+        latent_state = x['global_state'].reshape(batch_size*agent_num, -1)
+        latent_state = self.global_encoder(latent_state)
+        return latent_state
+        # agent_state_B = agent_state.reshape(batch_size, -1)
+        # agent_state_B_A = agent_state.reshape(batch_size, agent_num, -1)
         # global
-        global_state = self.global_encoder(x['global_state'])
-        global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1))
-        return (agent_state_B, global_state)
+        # global_state = self.global_encoder(x['global_state'])
+        # global_state = self.encoder(torch.cat((agent_state_B, global_state),dim=1))
+        # return (agent_state_B, global_state)


 class PettingZooPrediction(nn.Module):
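After the simplification, forward only encodes the agent-specific global state and returns a single latent tensor of width 256 (the global_encoder's hidden_channels). A minimal shape sketch, assuming one agent and the (1, 14) global state from the config; the fake batch and the constructed cfg are illustrative placeholders:

import torch

# Sketch only: a fake batch of 4 observations for the single-agent ptz simple env.
obs = {'global_state': torch.zeros(4, 1, 14)}
encoder = PettingZooEncoder(cfg)   # cfg as assembled by the entry script
latent = encoder(obs)
print(latent.shape)                # expected: (4, 256) = (batch * agent_num, hidden_channels)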
