Skip to content

Commit a6491b3

Browse files
committed
Fix device of channel_dirs
1 parent a7d126a commit a6491b3

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ def __init__(
223223
)
224224
# Transform into torch tensor instead of numpy array, because this is expected
225225
# later on.
226-
self.channel_dirs = th.tensor(self.channel_dirs)
226+
self.channel_dirs = th.tensor(self.channel_dirs).to(
227+
self.model_inputs_full.device
228+
)
227229

228230
def vis_traditional(
229231
self,

0 commit comments

Comments
 (0)