@@ -50,6 +50,7 @@ def from_name(cls, name: str):


 transformer_configs = {
+    "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -109,6 +110,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
+        x = (self.config.dim ** 0.5) * x

         for i, layer in enumerate(self.layers):
             x = layer(x, input_pos, freqs_cis, mask)
@@ -195,7 +197,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x)) * self.w3(x))


 class RMSNorm(nn.Module):
@@ -209,7 +211,7 @@ def _norm(self, x):

     def forward(self, x: Tensor) -> Tensor:
         output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return output * (1 + self.weight)


 def precompute_freqs_cis(
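The hunks above line up with the Gemma architecture: token embeddings are scaled by sqrt(dim) before the first block, the feed-forward gate uses GELU (a GeGLU MLP) instead of Llama's SiLU gating, and RMSNorm applies an effective scale of (1 + weight). A minimal, self-contained sketch of those three pieces follows; it is illustrative only, not the repository's exact code, and the class names GemmaRMSNorm and GeGLUFeedForward (plus the zero-initialized norm weight) are assumptions made for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GemmaRMSNorm(nn.Module):
    """RMSNorm variant whose learned scale is applied as (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Assumed zero init so that (1 + weight) starts as an identity scale.
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * (1 + self.weight)


class GeGLUFeedForward(nn.Module):
    """Gated MLP with a GELU gate, mirroring the w1/w2/w3 layout in the diff."""

    def __init__(self, dim: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(dim, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GELU-gated linear unit instead of SiLU gating.
        return self.w2(F.gelu(self.w1(x)) * self.w3(x))


if __name__ == "__main__":
    dim = 2048
    x = torch.randn(1, 4, dim)
    # Embedding scaling by sqrt(dim), as added to Transformer.forward above.
    x = (dim ** 0.5) * x
    x = GemmaRMSNorm(dim)(x)
    x = GeGLUFeedForward(dim, intermediate_size=16384)(x)
    print(x.shape)  # torch.Size([1, 4, 2048])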