
Commit d983e3d

remove divisible length constraints
1 parent e53ffb5 commit d983e3d

File tree

1 file changed (+6 / -6 lines)

generate.py

Lines changed: 6 additions & 6 deletions

@@ -74,6 +74,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor
     assert input_pos.shape[-1] == 1
     block_index = input_pos // block_mask.BLOCK_SIZE[0]
     mask = block_mask[:, :, block_index]
+    mask.mask_mod = block_mask.mask_mod
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)

@@ -167,11 +168,10 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(-1)
     T_new = T + max_new_tokens
-    T_buf = roundup(T_new, 128) # round up to multiple of 128 to use flex_attention
     if interactive:
-        max_seq_length = 384
+        max_seq_length = 350
     else:
-        max_seq_length = min(T_buf, model.config.block_size)
+        max_seq_length = min(T_new, model.config.block_size)

     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
@@ -181,7 +181,7 @@ def generate(
             draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(batch_size, T_buf, dtype=dtype, device=device)
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
     # We are just making the same prompt for every batch
     prompt = prompt.view(1, -1).repeat(batch_size, 1)
     empty[:, :T] = prompt
@@ -214,12 +214,12 @@ def generate(
             next_token = next_tokens[-1]
     else:
         generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[:, T + 1:T_new] = torch.cat(generated_tokens, dim=-1)
+        seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
     }
-    return seq[:T_new], generate_stats
+    return seq, generate_stats

 def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
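For context, a minimal sketch of the behaviour this commit removes. The roundup helper below is assumed (its definition is not part of this diff) and the shapes are illustrative only: before the change, the sequence buffer was padded up to a multiple of 128 to satisfy flex_attention and then sliced back down on return; after the change, the buffer is allocated at exactly T_new tokens and returned whole.

import torch

def roundup(val: int, multiple: int) -> int:
    # Assumed helper matching the removed call site: round val up to the
    # next multiple of `multiple`, e.g. roundup(437, 128) == 512.
    return ((val + multiple - 1) // multiple) * multiple

# Before this commit: pad the buffer length so flex_attention sees a
# multiple of its 128-token block size, then trim the result afterwards.
T, max_new_tokens = 37, 400
T_new = T + max_new_tokens        # 437 tokens actually needed
T_buf = roundup(T_new, 128)       # 512-slot buffer, 75 wasted slots

# After this commit: allocate exactly what is needed; generated tokens fill
# seq[:, T + 1:] and the whole buffer is returned without slicing.
empty = torch.empty(1, T_new, dtype=torch.int64, device="cpu")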
