@@ -67,7 +67,7 @@ def causal_mask(b, h, q, kv):

 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
+    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]

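The change above swaps the hard-coded "cuda" device for the device of the input tensor, so the block mask is built wherever the prompt tokens actually live. A minimal sketch of the idea, assuming the `create_block_mask` API from `torch.nn.attention.flex_attention`; the tensor shapes and `max_seq_length` value are illustrative, not taken from the diff:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q, kv):
    # mask_mod: a query position may attend to key positions at or before it
    return q >= kv

x = torch.randint(0, 32000, (1, 16))   # prompt tokens; may live on CPU or GPU
input_pos = torch.arange(x.shape[-1], device=x.device)
max_seq_length = 1024                   # illustrative KV-cache length

# Deriving the device from x keeps the mask on the same device as the inputs,
# so prefill no longer assumes a CUDA machine.
mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], max_seq_length, device=x.device)
```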
@@ -77,11 +77,12 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
     block_index = input_pos // block_mask.BLOCK_SIZE[0]
     mask = block_mask[:, :, block_index]
     mask.mask_mod = block_mask.mask_mod
+    mask.seq_lengths = (1, model.max_seq_length)
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)

 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
-    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
+    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device)
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
         next_token, next_prob = decode_one_token(
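In the decode path, the full mask is now created once per decode loop on `cur_token.device`, and each step slices out the block-mask row for the current position; the added line re-attaches the `(Q_LEN, KV_LEN)` metadata as `(1, model.max_seq_length)`, since the sliced mask describes a single query token attending over the whole cache. A small sketch of that per-step slicing, assuming the same FlexAttention API; `max_seq_length` and the position value are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q, kv):
    return q >= kv

max_seq_length = 1024
cur_token = torch.zeros(1, 1, dtype=torch.long)  # last generated token, on any device

# Built once for the whole decode loop, on the same device as the tokens.
block_mask = create_block_mask(causal_mask, 1, 1, max_seq_length, max_seq_length,
                               device=cur_token.device)

# Per step: pick the block-row containing the current position ...
input_pos = torch.tensor([42], device=cur_token.device)
block_index = input_pos // block_mask.BLOCK_SIZE[0]
mask = block_mask[:, :, block_index]
# ... and restore the metadata the slice does not carry over: the mask_mod and
# the (Q_LEN, KV_LEN) pair, which for single-token decode is (1, max_seq_length).
mask.mask_mod = block_mask.mask_mod
mask.seq_lengths = (1, max_seq_length)
```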