 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import (
+    _mask_mod_signature,
+    BlockMask,
+    flex_attention,
+)


 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
     return n + k - (n % k)

+
+def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
+    def _mask_mod(b, h, q, kv):
+        return mask_mod(b, h, q + offset, kv)
+
+    return _mask_mod
+
+
 @dataclass
 class ModelArgs:
     block_size: int = 2048
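Note, not part of the diff: get_mask_mod wraps a flex-attention mask_mod so that query indices, which are numbered from 0 for the tokens in the current forward call, are shifted by their absolute position in the KV cache before the wrapped predicate runs. A minimal illustration, assuming a standard causal mask_mod; the names causal and shifted are placeholders:

def causal(b, h, q_idx, kv_idx):
    # Standard causal rule expressed in absolute positions.
    return q_idx >= kv_idx

# Decoding a single token that sits at absolute position 5 in the cache:
shifted = get_mask_mod(causal, offset=5)
# shifted(b, h, 0, 3) evaluates causal(b, h, 5, 3), so the new token can see
# cache slots 0..5 even though its local query index is 0.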
@@ -113,6 +126,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self.get_mask_mod = get_mask_mod

     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -131,11 +145,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
-        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

-    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)

@@ -158,7 +171,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -188,7 +201,7 @@ def load_hook(self, state_dict, prefix, *args):
         wv = state_dict.pop(prefix + "wv.weight")
         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -206,9 +219,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)

-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

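Usage sketch, not part of the diff: the model forward now expects a BlockMask built by the caller, for example with create_block_mask from torch.nn.attention.flex_attention. The names model, prompt, and causal below are assumptions, and only the prefill step is shown:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

S = prompt.size(-1)  # number of prompt tokens
input_pos = torch.arange(S, device=prompt.device)
# KV length matches the cache allocated in setup_caches; B=None, H=None
# broadcasts a single mask over batch and heads.
prefill_mask = create_block_mask(causal, None, None, S, model.max_seq_length, device=prompt.device)
logits = model(prefill_mask, prompt.view(1, -1), input_pos)

Because flex_attention is called with enable_gqa when n_head != n_local_heads, the key/value heads are broadcast across the query heads inside the kernel, which is what makes the removed repeat_interleave calls unnecessary.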