
Commit 0e6dfd3

Author: jayyoung0802 (committed)

feature(yzj): ptz simple mz cfg is ready and add ptz simple ez cfg

1 parent 08ef2ba, commit 0e6dfd3

File tree: 10 files changed, +209 / -50 lines

lzero/mcts/buffer/game_buffer_efficientzero.py

Lines changed: 19 additions & 3 deletions
@@ -9,6 +9,8 @@
 from lzero.mcts.utils import prepare_observation
 from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
 from .game_buffer_muzero import MuZeroGameBuffer
+from ding.torch_utils import to_device, to_tensor
+from ding.utils.data import default_collate
 
 
 @BUFFER_REGISTRY.register('game_buffer_efficientzero')
@@ -100,7 +102,15 @@ def _prepare_reward_value_context(
             - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
               td_steps_list, action_mask_segment, to_play_segment
         """
-        zero_obs = game_segment_list[0].zero_obs()
+        # zero_obs = game_segment_list[0].zero_obs()
+        # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
+        #                       'global_state': np.zeros((84,), dtype=np.float32),
+        #                       'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
+        #                       'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
+        zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32),
+                              'global_state': np.zeros((14, ), dtype=np.float32),
+                              'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
+                              'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}])
         value_obs_list = []
         # the value is valid or not (out of trajectory)
         value_mask = []
@@ -152,7 +162,7 @@ def _prepare_reward_value_context(
                 value_mask.append(0)
                 obs = zero_obs
 
-            value_obs_list.append(obs)
+            value_obs_list.append(obs.tolist())
 
         reward_value_context = [
             value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
@@ -196,7 +206,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             beg_index = self._cfg.mini_infer_size * i
             end_index = self._cfg.mini_infer_size * (i + 1)
 
-            m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']:
+                m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure':
+                m_obs = value_obs_list[beg_index:end_index]
+                m_obs = sum(m_obs, [])
+                m_obs = default_collate(m_obs)
+                m_obs = to_device(m_obs, self._cfg.device)
 
             # calculate the target value
             m_output = model.initial_inference(m_obs)
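Note on the new 'structure' branch above: each element appended to value_obs_list is now a Python list of dict observations rather than a stacked array, so the mini-batch is flattened with sum(..., []) and batched with DI-engine's default_collate before being moved to the target device. A minimal, standalone sketch of that path, assuming default_collate/to_device behave as in DI-engine for dicts of numpy arrays (shapes mirror the zero_obs placeholder above):

import numpy as np
from ding.utils.data import default_collate
from ding.torch_utils import to_device

# four buffer entries, each a list holding one dict observation (one agent)
value_obs_list = [
    [{'agent_state': np.zeros((1, 6), dtype=np.float32),
      'global_state': np.zeros((14, ), dtype=np.float32)}]
    for _ in range(4)
]
m_obs = value_obs_list[0:4]      # mini-infer slice
m_obs = sum(m_obs, [])           # flatten list-of-lists into a flat list of dicts
m_obs = default_collate(m_obs)   # dict of batched tensors, e.g. agent_state -> (4, 1, 6)
m_obs = to_device(m_obs, 'cpu')  # the buffer uses self._cfg.device here
print({k: v.shape for k, v in m_obs.items()})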

lzero/mcts/buffer/game_buffer_muzero.py

Lines changed: 2 additions & 2 deletions
@@ -202,11 +202,11 @@ def _prepare_reward_value_context(
         """
         zero_obs = game_segment_list[0].zero_obs()
         # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
-        #                       'global_state': np.zeros((30,), dtype=np.float32),
+        #                       'global_state': np.zeros((84,), dtype=np.float32),
         #                       'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
         #                       'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
         zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32),
-                              'global_state': np.zeros((1, 14), dtype=np.float32),
+                              'global_state': np.zeros((14, ), dtype=np.float32),
                               'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
                               'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}])
         value_obs_list = []

lzero/model/efficientzero_model_mlp.py

Lines changed: 42 additions & 28 deletions
@@ -8,6 +8,7 @@
 
 from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP
 from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
+from ding.utils.default_helper import get_shape0
 
 
 @MODEL_REGISTRY.register('EfficientZeroModelMLP')
@@ -36,6 +37,9 @@ def __init__(
             norm_type: Optional[str] = 'BN',
             discrete_action_encoding_type: str = 'one_hot',
             res_connection_in_dynamics: bool = False,
+            state_encoder=None,
+            state_prediction=None,
+            state_dynamics=None,
             *args,
             **kwargs,
     ):
@@ -104,31 +108,40 @@ def __init__(
         self.state_norm = state_norm
         self.res_connection_in_dynamics = res_connection_in_dynamics
 
-        self.representation_network = RepresentationNetworkMLP(
-            observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
-        )
-
-        self.dynamics_network = DynamicsNetworkMLP(
-            action_encoding_dim=self.action_encoding_dim,
-            num_channels=latent_state_dim + self.action_encoding_dim,
-            common_layer_num=2,
-            lstm_hidden_size=lstm_hidden_size,
-            fc_reward_layers=fc_reward_layers,
-            output_support_size=self.reward_support_size,
-            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
-            norm_type=norm_type,
-            res_connection_in_dynamics=self.res_connection_in_dynamics,
-        )
-
-        self.prediction_network = PredictionNetworkMLP(
-            action_space_size=action_space_size,
-            num_channels=latent_state_dim,
-            fc_value_layers=fc_value_layers,
-            fc_policy_layers=fc_policy_layers,
-            output_support_size=self.value_support_size,
-            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
-            norm_type=norm_type
-        )
+        if state_encoder == None:
+            self.representation_network = RepresentationNetworkMLP(
+                observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
+            )
+        else:
+            self.representation_network = state_encoder
+
+        if state_dynamics == None:
+            self.dynamics_network = DynamicsNetworkMLP(
+                action_encoding_dim=self.action_encoding_dim,
+                num_channels=latent_state_dim + self.action_encoding_dim,
+                common_layer_num=2,
+                lstm_hidden_size=lstm_hidden_size,
+                fc_reward_layers=fc_reward_layers,
+                output_support_size=self.reward_support_size,
+                last_linear_layer_init_zero=self.last_linear_layer_init_zero,
+                norm_type=norm_type,
+                res_connection_in_dynamics=self.res_connection_in_dynamics,
+            )
+        else:
+            self.dynamics_network = state_dynamics
+
+        if state_prediction == None:
+            self.prediction_network = PredictionNetworkMLP(
+                action_space_size=action_space_size,
+                num_channels=latent_state_dim,
+                fc_value_layers=fc_value_layers,
+                fc_policy_layers=fc_policy_layers,
+                output_support_size=self.value_support_size,
+                last_linear_layer_init_zero=self.last_linear_layer_init_zero,
+                norm_type=norm_type
+            )
+        else:
+            self.prediction_network = state_prediction
 
         if self.self_supervised_learning_loss:
             # self_supervised_learning_loss related network proposed in EfficientZero
@@ -171,15 +184,16 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput:
             - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
             - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size.
         """
-        batch_size = obs.size(0)
+        batch_size = get_shape0(obs)
         latent_state = self._representation(obs)
+        device = latent_state.device
         policy_logits, value = self._prediction(latent_state)
         # zero initialization for reward hidden states
         # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size)
         reward_hidden_state = (
             torch.zeros(1, batch_size,
-                        self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size,
-                                                                           self.lstm_hidden_size).to(obs.device)
+                        self.lstm_hidden_size).to(device), torch.zeros(1, batch_size,
+                                                                       self.lstm_hidden_size).to(device)
         )
         return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state)
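The new state_encoder / state_prediction / state_dynamics arguments make the three sub-networks injectable: when left as None the original MLP networks are built, otherwise the passed-in modules are used, which is what lets a structured-observation encoder be plugged into EfficientZeroModelMLP. get_shape0 is used so the batch size can be read from a dict/list observation as well as from a plain tensor. A rough sketch of the injection idea; DictObsEncoder and its shapes are hypothetical, not part of this commit:

import torch
import torch.nn as nn

class DictObsEncoder(nn.Module):
    # Hypothetical encoder mapping a dict observation to a latent vector.
    def __init__(self, agent_obs_shape: int = 6, latent_dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(agent_obs_shape, latent_dim)

    def forward(self, obs: dict) -> torch.Tensor:
        # obs['agent_state']: (B, 1, agent_obs_shape) -> (B, latent_dim)
        return self.fc(obs['agent_state'].squeeze(1))

encoder = DictObsEncoder()
print(encoder({'agent_state': torch.zeros(4, 1, 6)}).shape)  # torch.Size([4, 256])
# model = EfficientZeroModelMLP(..., state_encoder=encoder)  # with state_encoder=None the default RepresentationNetworkMLP is built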

lzero/policy/efficientzero.py

Lines changed: 16 additions & 3 deletions
@@ -18,6 +18,8 @@
     prepare_obs, \
     configure_optimizers
 from lzero.policy.muzero import MuZeroPolicy
+from ding.utils.data import default_collate
+from ding.torch_utils import to_device, to_tensor
 
 
 @POLICY_REGISTRY.register('efficientzero')
@@ -307,7 +309,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
 
         target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
         target_value = target_value.view(self._cfg.batch_size, -1)
-        assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)
+        # assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)
 
         # ``scalar_transform`` to transform the original value to the scaled value,
         # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
@@ -562,7 +564,13 @@ def _forward_collect(
         self._collect_model.eval()
         self._collect_mcts_temperature = temperature
         self.collect_epsilon = epsilon
-        active_collect_env_num = data.shape[0]
+        active_collect_env_num = len(data)
+        #
+        data = sum(data, [])
+        data = default_collate(data)
+        data = to_device(data, self._device)
+        to_play = np.array(to_play).reshape(-1).tolist()
+
         with torch.no_grad():
             # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
             network_output = self._collect_model.initial_inference(data)
@@ -667,7 +675,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
             ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
         """
         self._eval_model.eval()
-        active_eval_env_num = data.shape[0]
+        active_eval_env_num = len(data)
+        #
+        data = sum(data, [])
+        data = default_collate(data)
+        data = to_device(data, self._device)
+        to_play = np.array(to_play).reshape(-1).tolist()
         with torch.no_grad():
             # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
             network_output = self._eval_model.initial_inference(data)
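With structured observations, the data reaching _forward_collect/_forward_eval is a list with one entry per active environment (each entry holding that env's per-agent dict observations) instead of a stacked tensor, so len(data) replaces data.shape[0] and the list is flattened and collated before initial_inference. A small illustration with placeholder values (not the collector's real output):

import numpy as np

data = [[{'agent_state': np.ones((1, 6), dtype=np.float32)}],
        [{'agent_state': np.zeros((1, 6), dtype=np.float32)}]]  # two active envs, one agent each
active_collect_env_num = len(data)                  # 2; data.shape[0] would fail on a list
flat = sum(data, [])                                # flat list of per-agent dict observations
to_play = np.array([-1, -1]).reshape(-1).tolist()   # normalize to_play to a flat python list
print(active_collect_env_num, len(flat), to_play)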
New file · Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+from easydict import EasyDict
+
+env_name = 'ptz_simple'
+multi_agent = True
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+n_agent = 1
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+n_episode = 8
+batch_size = 256
+num_simulations = 50
+update_per_collect = 50
+reanalyze_ratio = 0.
+action_space_size = 5
+eps_greedy_exploration_in_collect = True
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+main_config = dict(
+    exp_name=
+    f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+    env=dict(
+        env_family='mpe',
+        env_id='simple_v2',
+        n_agent=n_agent,
+        n_landmark=n_landmark,
+        max_cycles=25,
+        agent_obs_only=False,
+        agent_specific_global_state=True,
+        continuous_actions=False,
+        stop_value=0,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        multi_agent=multi_agent,
+        ignore_done=False,
+        model=dict(
+            model_type='structure',
+            latent_state_dim=256,
+            action_space='discrete',
+            action_space_size=action_space_size,
+            agent_num=n_agent,
+            self_supervised_learning_loss=False,  # default is False
+            agent_obs_shape=6,
+            global_obs_shape=14,
+            discrete_action_encoding_type='one_hot',
+            global_cooperation=True,  # TODO: doesn't work now
+            hidden_size_list=[256, 256],
+            norm_type='BN',
+        ),
+        cuda=True,
+        mcts_ctree=True,
+        gumbel_algo=False,
+        env_type='not_board_games',
+        game_segment_length=30,
+        random_collect_episode_num=0,
+        eps=dict(
+            eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+            type='linear',
+            start=1.,
+            end=0.05,
+            decay=int(2e5),
+        ),
+        use_augmentation=False,
+        update_per_collect=update_per_collect,
+        batch_size=batch_size,
+        optim_type='SGD',
+        lr_piecewise_constant_decay=True,
+        learning_rate=0.2,
+        ssl_loss_weight=0,  # NOTE: default is 0.
+        num_simulations=num_simulations,
+        reanalyze_ratio=reanalyze_ratio,
+        n_episode=n_episode,
+        eval_freq=int(2e3),
+        replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in the terms of transitions.
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+    ),
+    learn=dict(learner=dict(
+        log_policy=True,
+        hook=dict(log_show_after_iter=10, ),
+    ), ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+    env=dict(
+        import_names=['zoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+        type='petting_zoo',
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(
+        type='efficientzero',
+        import_names=['lzero.policy.efficientzero'],
+    ),
+    collector=dict(
+        type='episode_muzero',
+        import_names=['lzero.worker.muzero_collector'],
+    )
+)
+create_config = EasyDict(create_config)
+ptz_simple_spread_efficientzero_config = main_config
+ptz_simple_spread_efficientzero_create_config = create_config
+
+if __name__ == "__main__":
+    from zoo.petting_zoo.entry import train_muzero
+    train_muzero([main_config, create_config], seed=seed)
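The shapes hard-coded into the buffers' zero_obs placeholder earlier in this commit line up with this config: n_agent = 1, agent_obs_shape = 6, global_obs_shape = 14. A sketch of how such a placeholder could be built from those config values instead of being hard-coded; the helper and its 12-dim 'alone' defaults are only an illustration taken from the diff, not code in the repository:

import numpy as np

def make_zero_obs(n_agent: int, agent_obs_shape: int, global_obs_shape: int,
                  agent_alone_shape: int = 12, agent_alone_padding_shape: int = 12):
    # Mirrors the dict layout used by zero_obs in the game buffers.
    return np.array([{
        'agent_state': np.zeros((n_agent, agent_obs_shape), dtype=np.float32),
        'global_state': np.zeros((global_obs_shape, ), dtype=np.float32),
        'agent_alone_state': np.zeros((n_agent, agent_alone_shape), dtype=np.float32),
        'agent_alone_padding_state': np.zeros((n_agent, agent_alone_padding_shape), dtype=np.float32),
    }])

zero_obs = make_zero_obs(n_agent=1, agent_obs_shape=6, global_obs_shape=14)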

zoo/petting_zoo/config/ptz_simple_mz_config.py

Lines changed: 2 additions & 3 deletions
@@ -46,7 +46,6 @@
         model=dict(
             model_type='structure',
             latent_state_dim=256,
-            frame_stack_num=1,
             action_space='discrete',
             action_space_size=action_space_size,
             agent_num=n_agent,
@@ -69,7 +68,7 @@
             type='linear',
             start=1.,
             end=0.05,
-            decay=int(1e5),
+            decay=int(2e5),
         ),
         use_augmentation=False,
         update_per_collect=update_per_collect,
@@ -111,6 +110,6 @@
 ptz_simple_spread_muzero_config = main_config
 ptz_simple_spread_muzero_create_config = create_config
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     from zoo.petting_zoo.entry import train_muzero
     train_muzero([main_config, create_config], seed=seed)
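The eps block drives epsilon-greedy exploration during collection; with type='linear', epsilon presumably decays linearly from start to end over decay environment steps, so raising decay from int(1e5) to int(2e5) stretches the exploration phase. A rough illustration of that schedule (the actual implementation lives in DI-engine's epsilon-greedy utilities):

def linear_eps(step: int, start: float = 1., end: float = 0.05, decay: int = int(2e5)) -> float:
    # linear interpolation from start to end over `decay` steps, clamped at end afterwards
    return max(end, start - (start - end) * min(step, decay) / decay)

print(linear_eps(0), linear_eps(100_000), linear_eps(200_000))  # 1.0, 0.525, 0.05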

zoo/petting_zoo/config/ptz_simple_spread_mz_config.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
 evaluator_env_num = 8
 n_episode = 8
 batch_size = 256
-num_simulations = 50
+num_simulations = 200
 update_per_collect = 50
 reanalyze_ratio = 0.
 action_space_size = 5*5*5
@@ -32,7 +32,7 @@
         n_landmark=n_landmark,
         max_cycles=25,
         agent_obs_only=False,
-        agent_specific_global_state=False,
+        agent_specific_global_state=True,
         continuous_actions=False,
         stop_value=0,
         collector_env_num=collector_env_num,
@@ -52,7 +52,7 @@
             agent_num=n_agent,
             self_supervised_learning_loss=False,  # default is False
             agent_obs_shape=18,
-            global_obs_shape=30,
+            global_obs_shape=18*n_agent+30,  # 84
             discrete_action_encoding_type='one_hot',
             global_cooperation=True,  # TODO: doesn't work now
             hidden_size_list=[256, 256],
@@ -69,7 +69,7 @@
             type='linear',
             start=1.,
             end=0.05,
-            decay=int(1e5),
+            decay=int(2e5),
         ),
         use_augmentation=False,
         update_per_collect=update_per_collect,
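On the global_obs_shape change: with agent_specific_global_state=True the global state grows with the number of agents, and the expression 18*n_agent+30 matches the # 84 comment under the three-agent simple_spread setting assumed here:

n_agent = 3                 # simple_spread's usual agent count, assumed for this config
agent_obs_shape = 18
base_global_obs_shape = 30  # the previous hard-coded value
global_obs_shape = agent_obs_shape * n_agent + base_global_obs_shape
assert global_obs_shape == 84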

zoo/petting_zoo/entry/train_muzero.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ def train_muzero(
         from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder
     elif create_cfg.policy.type == 'efficientzero':
         from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+        from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder
     elif create_cfg.policy.type == 'sampled_efficientzero':
         from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
     elif create_cfg.policy.type == 'gumbel_muzero':
