@@ -74,6 +74,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
     assert input_pos.shape[-1] == 1
     block_index = input_pos // block_mask.BLOCK_SIZE[0]
     mask = block_mask[:, :, block_index]
+    mask.mask_mod = block_mask.mask_mod
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)

@@ -167,11 +168,10 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(-1)
     T_new = T + max_new_tokens
-    T_buf = roundup(T_new, 128)  # round up to multiple of 128 to use flex_attention
     if interactive:
-        max_seq_length = 384
+        max_seq_length = 350
     else:
-        max_seq_length = min(T_buf, model.config.block_size)
+        max_seq_length = min(T_new, model.config.block_size)

     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
@@ -181,7 +181,7 @@ def generate(
         draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(batch_size, T_buf, dtype=dtype, device=device)
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
     # We are just making the same prompt for every batch
     prompt = prompt.view(1, -1).repeat(batch_size, 1)
     empty[:, :T] = prompt
@@ -214,12 +214,12 @@ def generate(
         next_token = next_tokens[-1]
     else:
         generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[:, T + 1:T_new] = torch.cat(generated_tokens, dim=-1)
+        seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
     }
-    return seq[:T_new], generate_stats
+    return seq, generate_stats

 def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
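
Side note on the first hunk: it reattaches mask_mod by hand after slicing, which suggests that BlockMask.__getitem__ does not preserve that attribute, while flex_attention still needs it to mask out individual elements inside partially-valid blocks. A minimal sketch of the decode-time pattern, assuming torch.nn.attention.flex_attention from a recent PyTorch; the causal mask_mod and the sizes below are illustrative, not from this PR:

    # Sketch: slice a BlockMask down to the single query position used at
    # decode time, then restore mask_mod as the hunk above does.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    S = 256  # illustrative max sequence length
    block_mask = create_block_mask(causal, B=1, H=1, Q_LEN=S, KV_LEN=S, device="cpu")

    input_pos = torch.tensor([17])                       # current decode position
    block_index = input_pos // block_mask.BLOCK_SIZE[0]  # row of mask blocks covering the query
    mask = block_mask[:, :, block_index]                 # one-query-block slice
    mask.mask_mod = block_mask.mask_mod                  # reattach, as in the hunk above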