Skip to content

Commit 014acb7

Browse files
1B working in GPUs now
1 parent 4feeab6 commit 014acb7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gemma/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def generate(
677677
curr_local_mask_tensor = local_mask_tensor.index_select(
678678
2, input_positions_tensor
679679
) if local_mask_tensor is not None else None
680-
output_positions_tensor = torch.LongTensor([min_prompt_len - 1], device=device)
680+
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
681681
temperatures_tensor = None if not temperature else torch.FloatTensor(
682682
[temperature] * batch_size).to(device)
683683
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)

0 commit comments

Comments
 (0)