Commit 6ea3f9b

chosenone authored and committed
feature(yzj): polish ctde2-(8,3,5)
1 parent c323a44 commit 6ea3f9b

17 files changed: +941 -332 lines

lzero/mcts/buffer/game_buffer_efficientzero.py

Lines changed: 18 additions & 15 deletions
@@ -102,15 +102,15 @@ def _prepare_reward_value_context(
             - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
               td_steps_list, action_mask_segment, to_play_segment
         """
-        # zero_obs = game_segment_list[0].zero_obs()
-        # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
-        #                       'global_state': np.zeros((84,), dtype=np.float32),
-        #                       'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
-        #                       'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
-        zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32),
+        zero_obs = game_segment_list[0].zero_obs()
+        # zero_obs = np.array([{'agent_state': np.zeros((18,), dtype=np.float32),
+        #                       'global_state': np.zeros((48,), dtype=np.float32),
+        #                       'agent_alone_state': np.zeros((14,), dtype=np.float32),
+        #                       'agent_alone_padding_state': np.zeros((18,), dtype=np.float32),}])
+        zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32),
                               'global_state': np.zeros((14, ), dtype=np.float32),
-                              'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
-                              'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}])
+                              'agent_alone_state': np.zeros((12,), dtype=np.float32),
+                              'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}])
         value_obs_list = []
         # the value is valid or not (out of trajectory)
         value_mask = []
@@ -221,13 +221,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             # EfficientZero related core code
             # ==============================================================
             # if not in training, obtain the scalars of the value/reward
-            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                [
-                    m_output.latent_state,
-                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                    m_output.policy_logits
-                ]
-            )
+            # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+            #     [
+            #         m_output.latent_state,
+            #         inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+            #         m_output.policy_logits
+            #     ]
+            # )
+            m_output.latent_state = (to_detach_cpu_numpy(m_output.latent_state[0]), to_detach_cpu_numpy(m_output.latent_state[1]))
+            m_output.value = to_detach_cpu_numpy(inverse_scalar_transform(m_output.value, self._cfg.model.support_scale))
+            m_output.policy_logits = to_detach_cpu_numpy(m_output.policy_logits)
             m_output.reward_hidden_state = (
                 m_output.reward_hidden_state[0].detach().cpu().numpy(),
                 m_output.reward_hidden_state[1].detach().cpu().numpy()
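
Note on the change above: because the CTDE model now returns the latent state as an (agent, global) pair, the reanalyze path no longer pushes `m_output.latent_state` through a single `to_detach_cpu_numpy([...])` call and instead detaches each member separately. A minimal, self-contained sketch of that element-wise handling (the `detach_to_numpy` / `detach_latent_state` helpers below are illustrative stand-ins, not LightZero API):

```python
from typing import Tuple, Union

import numpy as np
import torch


def detach_to_numpy(t: torch.Tensor) -> np.ndarray:
    # Stand-in for LightZero's ``to_detach_cpu_numpy``: detach, move to CPU, convert to numpy.
    return t.detach().cpu().numpy()


def detach_latent_state(
    latent_state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    # Single-representation models carry one tensor; the CTDE variant carries an
    # (agent_latent_state, global_latent_state) pair, which must be detached element-wise.
    if isinstance(latent_state, tuple):
        return tuple(detach_to_numpy(x) for x in latent_state)
    return detach_to_numpy(latent_state)


if __name__ == "__main__":
    agent_latent = torch.randn(4, 6)    # (batch, agent latent dim)
    global_latent = torch.randn(4, 14)  # (batch, global latent dim)
    out = detach_latent_state((agent_latent, global_latent))
    print(type(out[0]), out[0].shape, out[1].shape)
```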

lzero/mcts/buffer/game_buffer_muzero.py

Lines changed: 17 additions & 15 deletions
@@ -201,14 +201,14 @@ def _prepare_reward_value_context(
               td_steps_list, action_mask_segment, to_play_segment
         """
         zero_obs = game_segment_list[0].zero_obs()
-        # zero_obs = np.array([{'agent_state': np.zeros((3, 18), dtype=np.float32),
-        #                       'global_state': np.zeros((84,), dtype=np.float32),
-        #                       'agent_alone_state': np.zeros((3, 14), dtype=np.float32),
-        #                       'agent_alone_padding_state': np.zeros((3, 18), dtype=np.float32),}])
-        zero_obs = np.array([{'agent_state': np.zeros((1, 6), dtype=np.float32),
+        zero_obs = np.array([{'agent_state': np.zeros((18,), dtype=np.float32),
+                              'global_state': np.zeros((48,), dtype=np.float32),
+                              'agent_alone_state': np.zeros((14,), dtype=np.float32),
+                              'agent_alone_padding_state': np.zeros((18,), dtype=np.float32),}])
+        zero_obs = np.array([{'agent_state': np.zeros((6,), dtype=np.float32),
                               'global_state': np.zeros((14, ), dtype=np.float32),
-                              'agent_alone_state': np.zeros((1, 12), dtype=np.float32),
-                              'agent_alone_padding_state': np.zeros((1, 12), dtype=np.float32),}])
+                              'agent_alone_state': np.zeros((12,), dtype=np.float32),
+                              'agent_alone_padding_state': np.zeros((12,), dtype=np.float32),}])
         value_obs_list = []
         # the value is valid or not (out of game_segment)
         value_mask = []
@@ -400,14 +400,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
 
         if not model.training:
             # if not in training, obtain the scalars of the value/reward
-            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                [
-                    m_output.latent_state,
-                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                    m_output.policy_logits
-                ]
-            )
-
+            # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+            #     [
+            #         m_output.latent_state,
+            #         inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+            #         m_output.policy_logits
+            #     ]
+            # )
+            m_output.latent_state = (to_detach_cpu_numpy(m_output.latent_state[0]), to_detach_cpu_numpy(m_output.latent_state[1]))
+            m_output.value = to_detach_cpu_numpy(inverse_scalar_transform(m_output.value, self._cfg.model.support_scale))
+            m_output.policy_logits = to_detach_cpu_numpy(m_output.policy_logits)
             network_output.append(m_output)
 
             # concat the output slices after model inference
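
The `zero_obs` padding constructed here mirrors the environment's dict observation; the shapes `(6,)`, `(14,)`, `(12,)`, `(12,)` are hard-coded for the (8, 3, 5) setting this commit targets. A small sketch of how such a padding observation could be built from a shape spec instead of literals (the `make_zero_obs` helper and `ctde_obs_shapes` dict are hypothetical, not part of the commit):

```python
from typing import Dict, Tuple

import numpy as np


def make_zero_obs(obs_shapes: Dict[str, Tuple[int, ...]]) -> np.ndarray:
    # Build one padding observation matching the dict-style multi-agent obs,
    # wrapped in a length-1 object array as the game buffer expects.
    return np.array([{k: np.zeros(shape, dtype=np.float32) for k, shape in obs_shapes.items()}])


# Shapes matching the values hard-coded in this commit (illustrative spec, not LightZero API).
ctde_obs_shapes = {
    'agent_state': (6,),
    'global_state': (14,),
    'agent_alone_state': (12,),
    'agent_alone_padding_state': (12,),
}

zero_obs = make_zero_obs(ctde_obs_shapes)
print(zero_obs[0]['agent_state'].shape, zero_obs[0]['global_state'].shape)
```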

lzero/mcts/tree_search/mcts_ctree.py

Lines changed: 32 additions & 14 deletions
@@ -96,7 +96,9 @@ def search(
         pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
 
         # the data storage of latent states: storing the latent state of all the nodes in one search.
-        latent_state_batch_in_search_path = [latent_state_roots]
+        agent_latent_state_roots, global_latent_state_roots = latent_state_roots
+        agent_latent_state_batch_in_search_path = [agent_latent_state_roots]
+        global_latent_state_batch_in_search_path = [global_latent_state_roots]
         # the data storage of value prefix hidden states in LSTM
         reward_hidden_state_c_batch = [reward_hidden_state_roots[0]]
         reward_hidden_state_h_batch = [reward_hidden_state_roots[1]]
@@ -108,7 +110,8 @@ def search(
         for simulation_index in range(self._cfg.num_simulations):
             # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.
 
-            latent_states = []
+            agent_latent_states = []
+            global_latent_states = []
             hidden_states_c_reward = []
             hidden_states_h_reward = []
 
@@ -132,11 +135,13 @@ def search(
 
             # obtain the latent state for leaf node
             for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
-                latent_states.append(latent_state_batch_in_search_path[ix][iy])
+                agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy])
+                global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy])
                 hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy])
                 hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])
 
-            latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
+            agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float()
+            global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float()
             hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device
                                                                                              ).unsqueeze(0)
             hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device
@@ -151,10 +156,12 @@ def search(
             At the end of the simulation, the statistics along the trajectory are updated.
             """
             network_output = model.recurrent_inference(
-                latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions
+                (agent_latent_states, global_latent_states), (hidden_states_c_reward, hidden_states_h_reward), last_actions
             )
+            network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state
 
-            network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+            network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state)
+            network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state)
             network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
             network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
             network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix))
@@ -164,7 +171,8 @@ def search(
                 network_output.reward_hidden_state[1].detach().cpu().numpy()
             )
 
-            latent_state_batch_in_search_path.append(network_output.latent_state)
+            agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state)
+            global_latent_state_batch_in_search_path.append(network_output_global_latent_state)
             # tolist() is to be compatible with cpp datatype.
             value_prefix_batch = network_output.value_prefix.reshape(-1).tolist()
             value_batch = network_output.value.reshape(-1).tolist()
@@ -273,7 +281,9 @@ def search(
         batch_size = roots.num
         pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
         # the data storage of latent states: storing the latent state of all the nodes in the search.
-        latent_state_batch_in_search_path = [latent_state_roots]
+        agent_latent_state_roots, global_latent_state_roots = latent_state_roots
+        agent_latent_state_batch_in_search_path = [agent_latent_state_roots]
+        global_latent_state_batch_in_search_path = [global_latent_state_roots]
 
         # minimax value storage
         min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
@@ -282,7 +292,8 @@ def search(
         for simulation_index in range(self._cfg.num_simulations):
             # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.
 
-            latent_states = []
+            agent_latent_states = []
+            global_latent_states = []
 
             # prepare a result wrapper to transport results between python and c++ parts
             results = tree_muzero.ResultsWrapper(num=batch_size)
@@ -302,9 +313,11 @@ def search(
 
             # obtain the latent state for leaf node
            for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
-                latent_states.append(latent_state_batch_in_search_path[ix][iy])
+                agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy])
+                global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy])
 
-            latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
+            agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float()
+            global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float()
             # .long() is only for discrete action
             last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()
             """
@@ -314,14 +327,19 @@ def search(
             MCTS stage 3: Backup
             At the end of the simulation, the statistics along the trajectory are updated.
             """
-            network_output = model.recurrent_inference(latent_states, last_actions)
+            network_output = model.recurrent_inference((agent_latent_states, global_latent_states), last_actions)
 
-            network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+            network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state
+
+            # network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+            network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state)
+            network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state)
             network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
             network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
             network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward))
 
-            latent_state_batch_in_search_path.append(network_output.latent_state)
+            agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state)
+            global_latent_state_batch_in_search_path.append(network_output_global_latent_state)
             # tolist() is to be compatible with cpp datatype.
             reward_batch = network_output.reward.reshape(-1).tolist()
             value_batch = network_output.value.reshape(-1).tolist()
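
The tree-search change duplicates the latent-state cache along the search path: one list for agent latents, one for global latents, both indexed by the same (simulation step, batch position) pair when gathering leaf nodes. A standalone sketch of that gather step under these assumptions (`gather_leaf_latents` and its argument names are illustrative, not the search code itself):

```python
from typing import List, Tuple

import numpy as np
import torch


def gather_leaf_latents(
    agent_cache: List[np.ndarray],
    global_cache: List[np.ndarray],
    path_indices: List[int],
    batch_indices: List[int],
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    # For each leaf, (ix, iy) addresses the simulation step that produced its parent
    # latent (ix) and the position inside that batch (iy); the agent and global caches
    # are indexed in lockstep so the pair stays aligned.
    agent_latents = [agent_cache[ix][iy] for ix, iy in zip(path_indices, batch_indices)]
    global_latents = [global_cache[ix][iy] for ix, iy in zip(path_indices, batch_indices)]
    agent_batch = torch.from_numpy(np.asarray(agent_latents)).to(device).float()
    global_batch = torch.from_numpy(np.asarray(global_latents)).to(device).float()
    return agent_batch, global_batch


if __name__ == "__main__":
    # Two cached simulation steps, batch of 3 roots, latent dims 6 (agent) and 14 (global).
    agent_cache = [np.random.randn(3, 6).astype(np.float32) for _ in range(2)]
    global_cache = [np.random.randn(3, 14).astype(np.float32) for _ in range(2)]
    a, g = gather_leaf_latents(agent_cache, global_cache, [0, 1, 1], [0, 2, 1])
    print(a.shape, g.shape)  # torch.Size([3, 6]) torch.Size([3, 14])
```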

lzero/model/efficientzero_model_mlp.py

Lines changed: 31 additions & 9 deletions
@@ -128,7 +128,17 @@ def __init__(
                 res_connection_in_dynamics=self.res_connection_in_dynamics,
             )
         else:
-            self.dynamics_network = state_dynamics
+            self.dynamics_network = state_dynamics(
+                action_encoding_dim=self.action_encoding_dim,
+                num_channels=latent_state_dim + self.action_encoding_dim,
+                common_layer_num=2,
+                lstm_hidden_size=lstm_hidden_size,
+                fc_reward_layers=fc_reward_layers,
+                output_support_size=self.reward_support_size,
+                last_linear_layer_init_zero=self.last_linear_layer_init_zero,
+                norm_type=norm_type,
+                res_connection_in_dynamics=self.res_connection_in_dynamics,
+            )
 
         if state_prediction == None:
             self.prediction_network = PredictionNetworkMLP(
@@ -141,7 +151,16 @@ def __init__(
                 norm_type=norm_type
             )
         else:
-            self.prediction_network = state_prediction
+            self.prediction_network = state_prediction(
+                action_space_size=action_space_size,
+                num_channels=latent_state_dim,
+                fc_value_layers=fc_value_layers,
+                fc_policy_layers=fc_policy_layers,
+                output_support_size=self.value_support_size,
+                last_linear_layer_init_zero=self.last_linear_layer_init_zero,
+                norm_type=norm_type
+
+            )
 
         if self.self_supervised_learning_loss:
             # self_supervised_learning_loss related network proposed in EfficientZero
@@ -186,7 +205,7 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput:
         """
         batch_size = get_shape0(obs)
         latent_state = self._representation(obs)
-        device = latent_state.device
+        device = latent_state[0].device
         policy_logits, value = self._prediction(latent_state)
         # zero initialization for reward hidden states
         # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size)
@@ -307,19 +326,22 @@ def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple,
         # e.g., torch.Size([8]) -> torch.Size([8, 1])
         action_encoding = action_encoding.unsqueeze(-1)
 
-        action_encoding = action_encoding.to(latent_state.device).float()
+        agent_latent_state, global_latent_state = latent_state
+        action_encoding = action_encoding.to(agent_latent_state.device).float()
         # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or
         # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type.
-        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
+        agent_state_action_encoding = torch.cat((agent_latent_state, action_encoding), dim=1)
+        global_state_action_encoding = torch.cat((agent_latent_state, global_latent_state, action_encoding), dim=1)
 
         # NOTE: the key difference with MuZero
-        next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network(
-            state_action_encoding, reward_hidden_state
+        (next_agent_latent_state, next_global_latent_state), next_reward_hidden_state, value_prefix = self.dynamics_network(
+            (agent_state_action_encoding, global_state_action_encoding), reward_hidden_state
         )
 
         if self.state_norm:
-            next_latent_state = renormalize(next_latent_state)
-        return next_latent_state, next_reward_hidden_state, value_prefix
+            next_agent_latent_state = renormalize(next_agent_latent_state)
+            next_global_latent_state = renormalize(next_global_latent_state)
+        return (next_agent_latent_state, next_global_latent_state), next_reward_hidden_state, value_prefix
 
     def project(self, latent_state: torch.Tensor, with_grad=True):
         """
