
Commit 4e917f5

fix comments
1 parent 3d0adb5 commit 4e917f5

File tree

2 files changed: +17 -14 lines changed


generate.py

Lines changed: 10 additions & 7 deletions

@@ -12,7 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask

 def device_sync(device):
     if "cuda" in device:
@@ -58,26 +58,29 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def roundup(val, multiplier):
     return ((val - 1) // multiplier + 1) * multiplier

+def causal_mask(b, h, q, kv):
+    return q >= kv
+
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    mask = create_block_mask(model.get_mask_mod(0), 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
+    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]

-def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
-    block_index = input_pos // model.block_mask.BLOCK_SIZE[0]
-    mask = model.block_mask[:, :, block_index]
-    mask.mask_mod = model.get_mask_mod(input_pos[0])
+    block_index = input_pos // block_mask.BLOCK_SIZE[0]
+    mask = block_mask[:, :, block_index]
     logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)

 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
         next_token, next_prob = decode_one_token(
-            model, cur_token, input_pos, **sampling_kwargs
+            model, cur_token, input_pos, block_mask, **sampling_kwargs
        )
         input_pos += 1
         new_tokens.append(next_token.clone())
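
The generate.py change hoists the BlockMask out of the model: decode_n_tokens now builds one full max_seq_length x max_seq_length causal mask from the new module-level causal_mask, and decode_one_token slices out the row of query blocks containing the current position instead of reading model.block_mask. A minimal standalone sketch of that precompute-then-slice pattern (not code from this commit; it assumes PyTorch 2.5+ with flex_attention and a CUDA device, and the sequence length and position below are illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q, kv):
    return q >= kv

max_seq_length = 1024

# Built once per generation, outside the decode loop (as decode_n_tokens does).
block_mask = create_block_mask(causal_mask, 1, 1, max_seq_length, max_seq_length, device="cuda")

# Per decode step (as decode_one_token does): keep only the row of query blocks
# that contains the current position.
input_pos = torch.tensor([37], device="cuda")
block_index = input_pos // block_mask.BLOCK_SIZE[0]  # BLOCK_SIZE is (Q_BLOCK, KV_BLOCK)
step_mask = block_mask[:, :, block_index]

# step_mask still carries plain causal_mask as its mask_mod; the model.py change
# below re-bases it by input_pos inside forward() before attention runs.

The slice keeps the kv-block structure for that query row, so no new mask has to be materialized per generated token.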

model.py

Lines changed: 7 additions & 7 deletions

@@ -24,11 +24,11 @@ def find_multiple(n: int, k: int) -> int:
     return n + k - (n % k)


-def get_causal_mask(offset):
-    def causal_mask(b, h, q, kv):
-        return offset + q >= kv
+def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
+    def _mask_mod(b, h, q, kv):
+        return mask_mod(b, h, q + offset, kv)

-    return causal_mask
+    return _mask_mod


 @dataclass
@@ -103,7 +103,7 @@ def update(self, input_pos, k_val, v_val):
         return k_out, v_out

 class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs, get_mask_mod: Callable[[int], _mask_mod_signature]) -> None:
+    def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config

@@ -135,10 +135,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
-        self.block_mask = create_block_mask(self.get_mask_mod(0), 1, 1, max_seq_length, max_seq_length, device="cuda")

     def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
+        mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)

@@ -150,7 +150,7 @@ def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = No

     @classmethod
     def from_name(cls, name: str):
-        return cls(ModelArgs.from_name(name), get_causal_mask)
+        return cls(ModelArgs.from_name(name))


 class TransformerBlock(nn.Module):
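
On the model.py side, the offset-baked get_causal_mask factory is replaced by get_mask_mod, which wraps an existing mask_mod so that query row 0 of the sliced mask is interpreted as absolute position offset; forward() applies it to whatever mask it receives. A standalone sketch of that re-basing (not code from this commit; annotations are dropped, and position 37 plus the pure-Python check are illustrative only):

import torch

def causal_mask(b, h, q, kv):
    return q >= kv

def get_mask_mod(mask_mod, offset):
    # Simplified mirror of the commit's helper: shift the query index, then delegate.
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)
    return _mask_mod

shifted = get_mask_mod(causal_mask, torch.tensor(37))

b = h = torch.tensor(0)
q = torch.tensor(0)  # the lone decode-step query sits at local row 0
assert bool(shifted(b, h, q, torch.tensor(37)))       # kv == 37: visible
assert not bool(shifted(b, h, q, torch.tensor(38)))   # kv == 38: masked out

Because the wrapper composes with whatever mask_mod it is given rather than hard-coding causality, this suggests the caller can supply mask_mods other than plain causal masking without touching the Transformer.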
