 import torch
 import torch._dynamo.config
 import torch._inductor.config
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask

 def device_sync(device):
     if "cuda" in device:
@@ -58,26 +58,29 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def roundup(val, multiplier):
     return ((val - 1) // multiplier + 1) * multiplier

+def causal_mask(b, h, q, kv):
+    return q >= kv
+
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    mask = create_block_mask(model.get_mask_mod(0), 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
+    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]

-def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
-    block_index = input_pos // model.block_mask.BLOCK_SIZE[0]
-    mask = model.block_mask[:, :, block_index]
-    mask.mask_mod = model.get_mask_mod(input_pos[0])
+    block_index = input_pos // block_mask.BLOCK_SIZE[0]
+    mask = block_mask[:, :, block_index]
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)

 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
         next_token, next_prob = decode_one_token(
-            model, cur_token, input_pos, **sampling_kwargs
+            model, cur_token, input_pos, block_mask, **sampling_kwargs
        )
         input_pos += 1
         new_tokens.append(next_token.clone())
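For context, a minimal standalone sketch (not part of the diff) of how the pieces changed here fit together: causal_mask is a flex_attention mask_mod predicate, create_block_mask precomputes a BlockMask over the full sequence once, and decode-time code slices out the query-block row for the current position instead of rebuilding the mask. The tensor shapes, the direct flex_attention call, and the input_pos value below are illustrative assumptions, not code from this PR; it assumes a CUDA device and a PyTorch build that ships torch.nn.attention.flex_attention (2.5+).

    # Illustrative sketch only; shapes and values are assumptions, not code from this PR.
    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal_mask(b, h, q, kv):
        # mask_mod: query position q may attend to key/value position kv iff q >= kv
        return q >= kv

    B, H, S, D = 1, 8, 256, 64  # hypothetical batch, heads, sequence length, head dim
    q = torch.randn(B, H, S, D, device="cuda")
    k = torch.randn(B, H, S, D, device="cuda")
    v = torch.randn(B, H, S, D, device="cuda")

    # Build the block mask once over the full sequence (broadcast over batch and heads),
    # analogous to what decode_n_tokens now does with model.max_seq_length.
    block_mask = create_block_mask(causal_mask, 1, 1, S, S, device="cuda")
    prefill_out = flex_attention(q, k, v, block_mask=block_mask)

    # For single-token decode, select the row of query blocks containing the current
    # position rather than rebuilding the mask, mirroring decode_one_token above.
    input_pos = torch.tensor([37], device="cuda")  # hypothetical decode position
    block_index = input_pos // block_mask.BLOCK_SIZE[0]
    decode_mask = block_mask[:, :, block_index]

Precomputing the BlockMask once and indexing into it keeps the per-token decode step free of mask construction, which is the point of threading block_mask through decode_one_token in this change.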