We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4feeab6 commit 014acb7Copy full SHA for 014acb7
gemma/model.py
@@ -677,7 +677,7 @@ def generate(
677
curr_local_mask_tensor = local_mask_tensor.index_select(
678
2, input_positions_tensor
679
) if local_mask_tensor is not None else None
680
- output_positions_tensor = torch.LongTensor([min_prompt_len - 1], device=device)
+ output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
681
temperatures_tensor = None if not temperature else torch.FloatTensor(
682
[temperature] * batch_size).to(device)
683
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
0 commit comments