Skip to content

Commit

Permalink
fix(yzj): fix device bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chosenone authored and chosenone committed Nov 26, 2023
1 parent 59c7c56 commit 829d86d
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ namespace tree
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
{
#ifdef __APPLE__
disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
Expand Down
1 change: 1 addition & 0 deletions lzero/policy/efficientzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
# obtain the oracle latent states from representation function.
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)

latent_state = to_tensor(latent_state)
Expand Down
1 change: 1 addition & 0 deletions lzero/policy/muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# obtain the oracle latent states from representation function.
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)

latent_state = to_tensor(latent_state)
Expand Down

0 comments on commit 829d86d

Please sign in to comment.