@@ -67,7 +67,7 @@ def causal_mask(b, h, q, kv):
67
67
68
68
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    """Run the prompt through the model in one forward pass and sample the first token.

    Args:
        model: the Transformer to run; ``model.max_seq_length`` bounds the mask width.
        x: prompt token tensor.
        input_pos: position indices of the prompt tokens; shape is [B, S], and
            its leading dimension sizes the query side of the block mask.
        **sampling_kwargs: forwarded verbatim to ``sample``.

    Returns:
        The sampled next token (first element of ``sample``'s return pair).
    """
    # Build the causal block mask on the same device as the input tokens,
    # so the function is not pinned to any one accelerator.
    prompt_len = input_pos.shape[0]
    block_mask = create_block_mask(
        causal_mask, 1, 1, prompt_len, model.max_seq_length, device=x.device
    )
    logits = model(block_mask, x, input_pos)
    # sample(...) returns (token, probs); prefill only needs the token.
    return sample(logits, **sampling_kwargs)[0]
73
73
@@ -77,11 +77,12 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
77
77
block_index = input_pos // block_mask .BLOCK_SIZE [0 ]
78
78
mask = block_mask [:, :, block_index ]
79
79
mask .mask_mod = block_mask .mask_mod
80
+ mask .seq_lengths = (1 , model .max_seq_length )
80
81
logits = model (mask , x , input_pos )
81
82
return sample (logits , ** sampling_kwargs )
82
83
83
84
def decode_n_tokens (model : Transformer , cur_token : torch .Tensor , input_pos : torch .Tensor , num_new_tokens : int , callback = lambda _ : _ , ** sampling_kwargs ):
84
- block_mask = create_block_mask (causal_mask , 1 , 1 , model .max_seq_length , model .max_seq_length , device = "cuda" )
85
+ block_mask = create_block_mask (causal_mask , 1 , 1 , model .max_seq_length , model .max_seq_length , device = cur_token . device )
85
86
new_tokens , new_probs = [], []
86
87
for i in range (num_new_tokens ):
87
88
next_token , next_prob = decode_one_token (
0 commit comments