Skip to content

Commit 4feeab6

Browse files
Gemma 1B: adding device for local_mask_tensor
1 parent de621eb commit 4feeab6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gemma/model.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -670,7 +670,7 @@ def generate(
670 670
-2.3819763e38).to(torch.float)
671 671
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
672 672
local_mask_tensor = mask_tensor + torch.tril(
673-
torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38),
673+
torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38, device=device),
674 674
diagonal=-self.config.sliding_window_size,
675 675
) if self.config.sliding_window_size else None
676 676
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)

0 commit comments

Comments (0)