@@ -26,7 +26,7 @@ class ModelArgs:
     dim: int = 4096
     intermediate_size: int = None
     n_local_heads: int = -1
-    head_dim: int = 64
+    head_dim: int = None
     rope_base: float = 10000
     norm_eps: float = 1e-5
@@ -37,7 +37,8 @@ def __post_init__(self):
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head

     @classmethod
     def from_name(cls, name: str):
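Note on the two hunks above: head_dim now defaults to None and is derived as dim // n_head in __post_init__ only when a config does not set it explicitly. That decoupling is what the new Gemma entry below relies on, since its head width is larger than dim / n_head. A minimal sketch of the fallback logic, using the gemma-7b values added in this PR:

# gemma-7b values from the transformer_configs entry added below
dim, n_head, head_dim = 3072, 16, 256
derived = dim // n_head          # 192 -- what the old unconditional assignment would force
assert head_dim != derived       # so an explicit per-config override is required
if head_dim is None:             # the new __post_init__ branch only fills in a missing value
    head_dim = derived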
@@ -51,6 +52,7 @@ def from_name(cls, name: str):

 transformer_configs = {
     "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
+    "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
-        head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
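setup_caches now reads head_dim straight from the config instead of recomputing dim // n_head, so both the per-layer KV caches and the RoPE table (freqs_cis) pick up the 256-wide heads. A rough size estimate for one layer's key (or value) cache under the gemma-7b config, assuming KVCache allocates a (batch, n_local_heads, seq, head_dim) buffer and using an illustrative max_seq_length of 2048:

max_batch_size, max_seq_length = 1, 2048     # max_seq_length chosen only for illustration
n_local_heads, head_dim = 16, 256            # gemma-7b values
elems = max_batch_size * n_local_heads * max_seq_length * head_dim
print(elems * 2 / 2**20, "MiB per K or V buffer in bf16")   # 16.0 MiB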
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
         self.kv_cache = None

         self.n_head = config.n_head
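The output projection's input width is the concatenated head width n_head * head_dim, which for gemma-7b is 16 * 256 = 4096 and no longer equals dim = 3072, so the old nn.Linear(config.dim, config.dim) would be shape-mismatched. A standalone shape check (a sketch, not part of the PR):

import torch
import torch.nn as nn

n_head, head_dim, dim = 16, 256, 3072       # gemma-7b values
wo = nn.Linear(n_head * head_dim, dim, bias=False)
y = torch.randn(1, 8, n_head * head_dim)    # (bsz, seqlen, concatenated heads)
print(wo(y).shape)                           # torch.Size([1, 8, 3072])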
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)

         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
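Likewise, the query slice of the fused wqkv output is n_head * head_dim wide rather than dim; the three split widths must sum to total_head_dim from __init__. A consistency check with the gemma-7b numbers:

n_head, n_local_heads, head_dim, dim = 16, 16, 256, 3072
q_size = n_head * head_dim                 # 4096
kv_size = n_local_heads * head_dim         # 4096
total_head_dim = (n_head + 2 * n_local_heads) * head_dim
assert q_size + 2 * kv_size == total_head_dim   # split sizes cover the fused projection exactly
assert q_size != dim                             # the old [self.dim, ...] split would misalign here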
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)

         y = self.wo(y)
         return y
@@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))


 class RMSNorm(nn.Module):
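The activation change swaps the default erf-based GeLU for the tanh approximation, which is what the Gemma reference implementation appears to use; the two are numerically close but not identical, so the difference can matter when comparing logits against reference outputs. Small illustration:

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
print(F.gelu(x))                        # exact (erf-based) GeLU, the previous behavior
print(F.gelu(x, approximate="tanh"))    # tanh approximation used in the hunk above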