From 8d71f9627f50e7b1c3c4fb802ee741bd48eb0d00 Mon Sep 17 00:00:00 2001 From: "xindong.he" Date: Wed, 22 Nov 2023 03:36:27 +0000 Subject: [PATCH] fix(yzj):fix device bug --- lzero/policy/efficientzero.py | 1 + lzero/policy/muzero.py | 1 + 2 files changed, 2 insertions(+) diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 001b46294..114851463 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -401,6 +401,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # obtain the oracle latent states from representation function. beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index cbca2c18c..eb27d63ba 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -369,6 +369,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # obtain the oracle latent states from representation function. beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) obs_target_batch_tmp = default_collate(obs_target_batch[:, beg_index:end_index].squeeze()) + obs_target_batch_tmp = to_device(obs_target_batch_tmp, self._device) network_output = self._learn_model.initial_inference(obs_target_batch_tmp) latent_state = to_tensor(latent_state)