Commit 7dd5661

Merge pull request #196 from pytorch-labs/bf/flex-decoding-integrate
Integrate Flex Decoding
2 parents 32971d3 + 8b91ce6

3 files changed: +45 −21 lines

eval.py (−1)

@@ -15,7 +15,6 @@
 torch._dynamo.config.automatic_dynamic_shapes = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.epilogue_fusion = False
-torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
 
 from tokenizer import get_tokenizer

generate.py (+26 −12)

@@ -12,6 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 
 def device_sync(device):
     if "cuda" in device:
@@ -30,6 +31,8 @@ def device_sync(device):
 
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
+create_block_mask = torch.compile(create_block_mask)
+
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
@@ -56,29 +59,40 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
+def roundup(val, multiplier):
+    return ((val - 1) // multiplier + 1) * multiplier
+
+def causal_mask(b, h, q, kv):
+    return q >= kv
+
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    logits = model(x, input_pos)
+    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
+    logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)[0]
 
-def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
-    logits = model(x, input_pos)
+    block_index = input_pos // block_mask.BLOCK_SIZE[0]
+    mask = block_mask[:, :, block_index]
+    mask.mask_mod = block_mask.mask_mod
+    mask.seq_lengths = (1, model.max_seq_length)
+    logits = model(mask, x, input_pos)
     return sample(logits, **sampling_kwargs)
 
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+    block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device)
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.clone()
+        next_token, next_prob = decode_one_token(
+            model, cur_token, input_pos, block_mask, **sampling_kwargs
+        )
+        input_pos += 1
+        new_tokens.append(next_token.clone())
+        callback(new_tokens[-1])
+        new_probs.append(next_prob.clone())
+        cur_token = next_token.clone()
 
     return new_tokens, new_probs
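
A note for readers new to FlexAttention: the decode path above builds one full-length BlockMask up front (in decode_n_tokens) and then, on every step, slices out the block-row containing the current query position instead of recreating the mask per token. The sketch below mirrors that pattern from decode_one_token; the 256-token cache length and position 17 are made-up values, and it assumes a CUDA device plus a PyTorch recent enough that BlockMask exposes BLOCK_SIZE, indexing, and seq_lengths, as this PR relies on.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q, kv):
    # True means "query position q may attend to key/value position kv"
    return q >= kv

max_seq_length = 256   # assumed KV-cache length for this sketch
device = "cuda"

# Build one block mask covering the whole cache, as decode_n_tokens does.
block_mask = create_block_mask(causal_mask, 1, 1, max_seq_length, max_seq_length, device=device)

# Per decode step: keep only the block-row holding the current position and
# narrow it to a single query token, as decode_one_token does.
input_pos = torch.tensor([17], device=device)          # assumed current position
block_index = input_pos // block_mask.BLOCK_SIZE[0]    # which query block-row to keep
mask = block_mask[:, :, block_index]                   # view with one query block-row
mask.mask_mod = block_mask.mask_mod                    # reuse the original predicate
mask.seq_lengths = (1, max_seq_length)                 # one query token vs. the full cache
# `mask` then goes through model.forward, which offsets mask_mod by input_pos[0].
```

Slicing a precomputed mask keeps every decode step's mask the same shape and avoids paying for create_block_mask on each generated token.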

model.py (+19 −8)

@@ -11,13 +11,26 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import (
+    _mask_mod_signature,
+    BlockMask,
+    flex_attention,
+)
 
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
     return n + k - (n % k)
 
+
+def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
+    def _mask_mod(b, h, q, kv):
+        return mask_mod(b, h, q + offset, kv)
+
+    return _mask_mod
+
+
 @dataclass
 class ModelArgs:
     block_size: int = 2048
@@ -113,6 +126,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self.get_mask_mod = get_mask_mod
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -131,11 +145,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
-        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
-    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
 
@@ -158,7 +171,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -188,7 +201,7 @@ def load_hook(self, state_dict, prefix, *args):
         wv = state_dict.pop(prefix + "wv.weight")
         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
@@ -206,9 +219,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
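
To make the new attention call concrete, here is a rough standalone comparison of the old path (repeat_interleave on the KV heads plus a dense causal mask fed to scaled_dot_product_attention) against flex_attention with enable_gqa as used above. The shapes (batch 1, 8 query heads, 2 KV heads, 128 tokens, head dim 64) are arbitrary assumptions and the final print is only a sanity check; this is an illustration, not part of the PR, and it assumes a CUDA device with a FlexAttention-enabled PyTorch.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q, kv):
    return q >= kv

B, n_head, n_local_heads, S, head_dim = 1, 8, 2, 128, 64  # assumed shapes
device = "cuda"

q = torch.randn(B, n_head, S, head_dim, device=device)
k = torch.randn(B, n_local_heads, S, head_dim, device=device)
v = torch.randn(B, n_local_heads, S, head_dim, device=device)

# Old path: duplicate KV heads for GQA and pass a dense boolean causal mask to SDPA.
k_rep = k.repeat_interleave(n_head // n_local_heads, dim=1)
v_rep = v.repeat_interleave(n_head // n_local_heads, dim=1)
dense_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
y_sdpa = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=dense_mask, dropout_p=0.0)

# New path: a BlockMask encodes the same causal structure, and enable_gqa lets the
# 8 query heads share the 2 KV heads without materializing copies.
block_mask = create_block_mask(causal_mask, 1, 1, S, S, device=device)
y_flex = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)

print("max abs difference between the two paths:", (y_sdpa - y_flex).abs().max().item())
```

During decode the same flex_attention call runs with a single-token query and the sliced mask from generate.py, after get_mask_mod has shifted the query index by the current cache position so the length-1 query is masked as if it sat at input_pos.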
