@@ -96,7 +96,9 @@ def search(
pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor

# the data storage of latent states: storing the latent state of all the nodes in one search.
- latent_state_batch_in_search_path = [latent_state_roots]
+ agent_latent_state_roots, global_latent_state_roots = latent_state_roots
+ agent_latent_state_batch_in_search_path = [agent_latent_state_roots]
+ global_latent_state_batch_in_search_path = [global_latent_state_roots]
# the data storage of value prefix hidden states in LSTM
reward_hidden_state_c_batch = [reward_hidden_state_roots[0]]
reward_hidden_state_h_batch = [reward_hidden_state_roots[1]]
@@ -108,7 +110,8 @@ def search(
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

- latent_states = []
+ agent_latent_states = []
+ global_latent_states = []
hidden_states_c_reward = []
hidden_states_h_reward = []

@@ -132,11 +135,13 @@ def search(

# obtain the latent state for leaf node
for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
- latent_states.append(latent_state_batch_in_search_path[ix][iy])
+ agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy])
+ global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy])
hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy])
hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])

- latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
+ agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float()
+ global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float()
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device
).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device
@@ -151,10 +156,12 @@ def search(
At the end of the simulation, the statistics along the trajectory are updated.
"""
network_output = model.recurrent_inference(
- latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions
+ (agent_latent_states, global_latent_states), (hidden_states_c_reward, hidden_states_h_reward), last_actions
)
+ network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state

- network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+ network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state)
+ network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix))
@@ -164,7 +171,8 @@ def search(
network_output.reward_hidden_state[1].detach().cpu().numpy()
)

- latent_state_batch_in_search_path.append(network_output.latent_state)
+ agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state)
+ global_latent_state_batch_in_search_path.append(network_output_global_latent_state)
# tolist() is to be compatible with cpp datatype.
value_prefix_batch = network_output.value_prefix.reshape(-1).tolist()
value_batch = network_output.value.reshape(-1).tolist()
@@ -273,7 +281,9 @@ def search(
batch_size = roots.num
pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
# the data storage of latent states: storing the latent state of all the nodes in the search.
- latent_state_batch_in_search_path = [latent_state_roots]
+ agent_latent_state_roots, global_latent_state_roots = latent_state_roots
+ agent_latent_state_batch_in_search_path = [agent_latent_state_roots]
+ global_latent_state_batch_in_search_path = [global_latent_state_roots]

# minimax value storage
min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
@@ -282,7 +292,8 @@ def search(
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

- latent_states = []
+ agent_latent_states = []
+ global_latent_states = []

# prepare a result wrapper to transport results between python and c++ parts
results = tree_muzero.ResultsWrapper(num=batch_size)
@@ -302,9 +313,11 @@ def search(

# obtain the latent state for leaf node
for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch):
- latent_states.append(latent_state_batch_in_search_path[ix][iy])
+ agent_latent_states.append(agent_latent_state_batch_in_search_path[ix][iy])
+ global_latent_states.append(global_latent_state_batch_in_search_path[ix][iy])

- latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
+ agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(self._cfg.device).float()
+ global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(self._cfg.device).float()
# .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()
"""
@@ -314,14 +327,19 @@ def search(
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
- network_output = model.recurrent_inference(latent_states, last_actions)
+ network_output = model.recurrent_inference((agent_latent_states, global_latent_states), last_actions)

- network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+ network_output_agent_latent_state, network_output_global_latent_state = network_output.latent_state
+
+ # network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+ network_output_agent_latent_state = to_detach_cpu_numpy(network_output_agent_latent_state)
+ network_output_global_latent_state = to_detach_cpu_numpy(network_output_global_latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward))

- latent_state_batch_in_search_path.append(network_output.latent_state)
+ agent_latent_state_batch_in_search_path.append(network_output_agent_latent_state)
+ global_latent_state_batch_in_search_path.append(network_output_global_latent_state)
# tolist() is to be compatible with cpp datatype.
reward_batch = network_output.reward.reshape(-1).tolist()
value_batch = network_output.value.reshape(-1).tolist()
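The pattern applied throughout this diff is the same in every hunk: the single per-node latent state becomes an (agent, global) pair, and both halves are stored, gathered by (search-path index, batch index), fed to ``recurrent_inference`` as a tuple, detached, and appended in parallel. The sketch below only illustrates that flow under stated assumptions; it is not the repository's code, and ``DummyModel``, ``expand_one_simulation``, and all shapes are hypothetical placeholders (the real model returns a structured network output whose ``latent_state`` attribute holds this pair).

```python
# Minimal sketch of threading an (agent, global) latent-state pair through one
# MCTS expansion step. All names and shapes here are illustrative assumptions.
import numpy as np
import torch


class DummyModel:
    """Stand-in for a two-stream dynamics model; echoes the latent pair it receives."""

    def recurrent_inference(self, latent_state_pair, last_actions):
        agent_latent, global_latent = latent_state_pair
        # A real model would apply its dynamics network here.
        return agent_latent, global_latent


def expand_one_simulation(model, agent_roots, global_roots, ix_list, iy_list, last_actions, device="cpu"):
    # Parallel storages: one list of agent latents and one of global latents,
    # each indexed first by position in the search path, then by batch index.
    agent_batch_in_search_path = [agent_roots]
    global_batch_in_search_path = [global_roots]

    # Gather the leaf latent states by (search-path index ix, batch index iy).
    agent_latent_states, global_latent_states = [], []
    for ix, iy in zip(ix_list, iy_list):
        agent_latent_states.append(agent_batch_in_search_path[ix][iy])
        global_latent_states.append(global_batch_in_search_path[ix][iy])

    agent_latent_states = torch.from_numpy(np.asarray(agent_latent_states)).to(device).float()
    global_latent_states = torch.from_numpy(np.asarray(global_latent_states)).to(device).float()
    last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).long()

    # The model consumes and produces the latent pair; each half is detached separately
    # and appended to its own storage, mirroring the two *_batch_in_search_path lists above.
    next_agent, next_global = model.recurrent_inference((agent_latent_states, global_latent_states), last_actions)
    agent_batch_in_search_path.append(next_agent.detach().cpu().numpy())
    global_batch_in_search_path.append(next_global.detach().cpu().numpy())
    return agent_batch_in_search_path, global_batch_in_search_path


# Usage: a batch of 2 roots, both leaves taken from the root layer (ix = 0).
agent_roots = np.random.randn(2, 4).astype(np.float32)
global_roots = np.random.randn(2, 8).astype(np.float32)
expand_one_simulation(DummyModel(), agent_roots, global_roots, [0, 0], [0, 1], [1, 3])
```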