 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import (
+    _mask_mod_signature,
+    BlockMask,
+    flex_attention,
+)
 
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
     return n + k - (n % k)
 
+
+def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
+    def _mask_mod(b, h, q, kv):
+        return mask_mod(b, h, q + offset, kv)
+
+    return _mask_mod
+
+
 @dataclass
 class ModelArgs:
     block_size: int = 2048
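The new get_mask_mod helper is the heart of the change: during incremental decoding the query tensor only covers the current position, while an existing mask_mod reasons in absolute sequence positions, so the wrapper shifts the local query index by the cache offset before delegating. A minimal sketch of how the wrapper behaves, assuming a standard causal mask_mod; the names below are illustrative and not part of the diff:

def causal(b, h, q_idx, kv_idx):
    # Standard causal rule, written in absolute positions.
    return q_idx >= kv_idx

# Decoding the token at absolute position 5: the local query index is 0,
# so the wrapped rule compares 0 + 5 against each kv index.
shifted = get_mask_mod(causal, offset=5)
assert shifted(0, 0, 0, 5)       # may attend to position 5 (itself)
assert not shifted(0, 0, 0, 6)   # may not attend to the future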
@@ -113,6 +126,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self.get_mask_mod = get_mask_mod
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -131,11 +145,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
-        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
-    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
 
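The calling contract changes with it: Transformer.forward now takes a BlockMask as its first argument and rebases that mask's mask_mod onto the absolute start position input_pos[0], instead of slicing a cached dense causal mask. A hedged sketch of what the prefill call site might look like under this API; model, prompt, and the lengths are assumed names, not taken from the diff:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

prompt_len, max_seq_length = 32, 2048
# Queries cover the prompt; keys/values cover the preallocated KV cache.
# B=None / H=None lets the mask broadcast over batch and heads.
prefill_mask = create_block_mask(causal, B=None, H=None,
                                 Q_LEN=prompt_len, KV_LEN=max_seq_length)
# logits = model(prefill_mask, prompt, input_pos=torch.arange(prompt_len))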
@@ -158,7 +171,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -188,7 +201,7 @@ def load_hook(self, state_dict, prefix, *args):
         wv = state_dict.pop(prefix + "wv.weight")
         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
@@ -206,9 +219,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
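With masking moved into the BlockMask, the attention call itself collapses to a single line: enable_gqa lets flex_attention group the n_head query heads over the n_local_heads key/value heads internally, so the explicit repeat_interleave of k and v is gone. A standalone sketch of that call pattern with assumed shapes; a PyTorch build with FlexAttention (2.5 or newer) is assumed, shown in eager mode even though the kernel is meant to be used under torch.compile:

import torch
from torch.nn.attention.flex_attention import flex_attention

bsz, n_head, n_local_heads, seqlen, head_dim = 1, 32, 8, 16, 64
q = torch.randn(bsz, n_head, seqlen, head_dim)
k = torch.randn(bsz, n_local_heads, seqlen, head_dim)  # fewer KV heads (GQA)
v = torch.randn(bsz, n_local_heads, seqlen, head_dim)

# No block_mask here, so this is full (unmasked) attention; enable_gqa
# broadcasts the 8 KV heads across the 32 query heads internally.
y = flex_attention(q, k, v, enable_gqa=True)
assert y.shape == (bsz, n_head, seqlen, head_dim)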