Skip to content

Commit

Permalink
fix(yzj):fix device bug
Browse files Browse the repository at this point in the history
  • Loading branch information
xindong.he committed Nov 22, 2023
1 parent 6ea3f9b commit 8d71f96
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions lzero/policy/efficientzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
# obtain the oracle latent states from representation function.
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)

latent_state = to_tensor(latent_state)
Expand Down
1 change: 1 addition & 0 deletions lzero/policy/muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# obtain the oracle latent states from representation function.
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze())
obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device)
network_output = self._learn_model.initial_inference(obs_target_batch_tmp)

latent_state = to_tensor(latent_state)
Expand Down

0 comments on commit 8d71f96

Please sign in to comment.