 import dataclasses
 import enum
+import os
 import torch
 from typing import Optional, Sequence
+from .siglip_vision import config as siglip_vision_config


 # Keep a mapping from dtype strings to the supported torch dtypes.
@@ -37,6 +39,7 @@ class AttentionType(enum.Enum):
 class Architecture(enum.Enum):
   GEMMA_1 = 1
   GEMMA_2 = 2
+  GEMMA_3 = 3


 @dataclasses.dataclass
@@ -66,7 +69,9 @@ class GemmaConfig:
   # Whether a quantized version of the model is used.
   quant: bool = False
   # The path to the model tokenizer.
-  tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
+  tokenizer: Optional[str] = (
+      'tokenizer/tokenizer.model'
+  )
   # The types of attention used in the layers of the model.
   attn_types: Optional[Sequence[AttentionType]] = None
   # The size of the sliding window used for local attention.
@@ -82,28 +87,38 @@ class GemmaConfig:
   use_pre_ffw_norm: bool = False
   # Whether to use post mlp normalization.
   use_post_ffw_norm: bool = False
+  # The wave length of the rotary embedding.
+  rope_wave_length: dict[AttentionType, int] | None = None
+  # Whether to use QK normalization in the attention blocks.
+  use_qk_norm: bool = False
+  # Vision model config.
+  vision_config: siglip_vision_config.SiglipVisionModelConfig | None = None
+  # The factor by which the rope wave length is divided for global layers.
+  rope_scaling_factor: int | None = None

   def get_dtype(self) -> Optional[torch.dtype]:
     """Gets the torch dtype from the config dtype string."""
     return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)


-def get_config_for_7b() -> GemmaConfig:
-  return GemmaConfig()
+def get_config_for_7b(dtype: str = 'bfloat16') -> GemmaConfig:
+  return GemmaConfig(dtype=dtype)


-def get_config_for_2b() -> GemmaConfig:
+def get_config_for_2b(dtype: str = 'bfloat16') -> GemmaConfig:
   return GemmaConfig(
+      dtype=dtype,
       num_hidden_layers=18,
       num_attention_heads=8,
       num_key_value_heads=1,
       hidden_size=2048,
-      intermediate_size=16384
+      intermediate_size=16384,
   )


-def get_config_for_2b_v2() -> GemmaConfig:
+def get_config_for_2b_v2(dtype: str = 'bfloat16') -> GemmaConfig:
   return GemmaConfig(
+      dtype=dtype,
       architecture=Architecture.GEMMA_2,
       num_hidden_layers=26,
       num_attention_heads=8,
@@ -120,8 +135,9 @@ def get_config_for_2b_v2() -> GemmaConfig:
   )


-def get_config_for_9b() -> GemmaConfig:
+def get_config_for_9b(dtype: str = 'bfloat16') -> GemmaConfig:
   return GemmaConfig(
+      dtype=dtype,
       architecture=Architecture.GEMMA_2,
       num_hidden_layers=42,
       num_attention_heads=16,
@@ -138,38 +154,187 @@ def get_config_for_9b() -> GemmaConfig:
   )


-def get_config_for_27b() -> GemmaConfig:
-  return GemmaConfig(
-      architecture=Architecture.GEMMA_2,
-      num_hidden_layers=46,
-      num_attention_heads=32,
-      num_key_value_heads=16,
-      hidden_size=4608,
-      intermediate_size=36864,
-      use_pre_ffw_norm=True,
-      use_post_ffw_norm=True,
-      final_logit_softcapping=30.0,
-      attn_logit_softcapping=50.0,
-      head_dim=128,
-      attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
-      sliding_window_size=4096,
-      query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
-  )
+def get_config_for_27b(dtype: str = 'bfloat16') -> GemmaConfig:
+  return GemmaConfig(
+      dtype=dtype,
+      architecture=Architecture.GEMMA_2,
+      num_hidden_layers=46,
+      num_attention_heads=32,
+      num_key_value_heads=16,
+      hidden_size=4608,
+      intermediate_size=36864,
+      use_pre_ffw_norm=True,
+      use_post_ffw_norm=True,
+      final_logit_softcapping=30.0,
+      attn_logit_softcapping=50.0,
+      head_dim=128,
+      attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
+      sliding_window_size=4096,
+      query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
+  )


-def get_model_config(variant: str) -> GemmaConfig:
-  if variant == '7b':
-    return get_config_for_7b()
-  elif variant == '2b':
-    return get_config_for_2b()
-  elif variant == '2b-v2':
-    return get_config_for_2b_v2()
-  elif variant == '9b':
-    return get_config_for_9b()
-  elif variant == '27b':
-    return get_config_for_27b()
-  else:
-    raise ValueError(
-        f'Invalid variant {variant}. Supported variants are "2b"'
-        'and "7b" and "9b" and "27b".')
+def get_config_for_1b(dtype: str) -> GemmaConfig:
+  return GemmaConfig(
+      dtype=dtype,
+      architecture=Architecture.GEMMA_3,
+      num_hidden_layers=26,
+      num_attention_heads=4,
+      num_key_value_heads=1,
+      hidden_size=1152,
+      intermediate_size=6912,
+      use_pre_ffw_norm=True,
+      use_post_ffw_norm=True,
+      head_dim=256,
+      attn_types=(
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.GLOBAL,
+      ),
+      sliding_window_size=512,
+      rope_wave_length={
+          AttentionType.LOCAL_SLIDING: 10_000,
+          AttentionType.GLOBAL: 1_000_000,
+      },
+      vocab_size=262_144,
+      max_position_embeddings=32_768,
+      tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+      use_qk_norm=True,
+      vision_config=None,
+  )

+
+def get_config_for_4b(dtype: str) -> GemmaConfig:
+  return GemmaConfig(
+      dtype=dtype,
+      architecture=Architecture.GEMMA_3,
+      num_hidden_layers=34,
+      num_attention_heads=8,
+      num_key_value_heads=4,
+      hidden_size=2560,
+      intermediate_size=10240,
+      use_pre_ffw_norm=True,
+      use_post_ffw_norm=True,
+      head_dim=256,
+      attn_types=(
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.GLOBAL,
+      ),
+      sliding_window_size=1024,
+      rope_wave_length={
+          AttentionType.LOCAL_SLIDING: 10_000,
+          AttentionType.GLOBAL: 1_000_000,
+      },
+      vocab_size=262_144,
+      tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+      use_qk_norm=True,
+      vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+      rope_scaling_factor=8,
+  )
+
+
+def get_config_for_12b(dtype: str) -> GemmaConfig:
+  return GemmaConfig(
+      dtype=dtype,
+      architecture=Architecture.GEMMA_3,
+      num_hidden_layers=48,
+      num_attention_heads=16,
+      num_key_value_heads=8,
+      hidden_size=3840,
+      intermediate_size=3840 * 8 // 2,
+      use_pre_ffw_norm=True,
+      use_post_ffw_norm=True,
+      head_dim=256,
+      attn_types=(
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.GLOBAL,
+      ),
+      sliding_window_size=1024,
+      rope_wave_length={
+          AttentionType.LOCAL_SLIDING: 10_000,
+          AttentionType.GLOBAL: 1_000_000,
+      },
+      vocab_size=262_144,
+      max_position_embeddings=131_072,
+      tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+      use_qk_norm=True,
+      vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+      rope_scaling_factor=8,
+  )
+
+
+def get_config_for_27b_v3(dtype: str) -> GemmaConfig:
+  return GemmaConfig(
+      dtype=dtype,
+      architecture=Architecture.GEMMA_3,
+      num_hidden_layers=62,
+      num_attention_heads=32,
+      num_key_value_heads=16,
+      hidden_size=5376,
+      intermediate_size=5376 * 8 // 2,
+      use_pre_ffw_norm=True,
+      use_post_ffw_norm=True,
+      head_dim=128,
+      query_pre_attn_scalar=5376 // 32,
+      attn_types=(
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.LOCAL_SLIDING,
+          AttentionType.GLOBAL,
+      ),
+      sliding_window_size=1024,
+      rope_wave_length={
+          AttentionType.LOCAL_SLIDING: 10_000,
+          AttentionType.GLOBAL: 1_000_000,
+      },
+      vocab_size=262_144,
+      max_position_embeddings=131_072,
+      tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+      use_qk_norm=True,
+      vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+      rope_scaling_factor=8,
+  )
+
+
+def get_model_config(variant: str, dtype: str = 'bfloat16') -> GemmaConfig:
+  """Gets the GemmaConfig for the desired variant and dtype."""
+  # Gemma1 variants
+  if variant == '7b':
+    return get_config_for_7b(dtype)
+  elif variant == '2b':
+    return get_config_for_2b(dtype)
+  # Gemma2 variants
+  elif variant == '2b-v2':
+    return get_config_for_2b_v2(dtype)
+  elif variant == '9b':
+    return get_config_for_9b(dtype)
+  elif variant == '27b':
+    return get_config_for_27b(dtype)
+  # Gemma3 variants
+  elif variant == '1b':
+    return get_config_for_1b(dtype)
+  elif variant == '4b':
+    return get_config_for_4b(dtype)
+  elif variant == '12b':
+    return get_config_for_12b(dtype)
+  elif variant == '27b_v3':
+    return get_config_for_27b_v3(dtype)
+  # Invalid variants
+  else:
+    raise ValueError(
+        f'Invalid variant {variant}. Supported variants are "1b", "2b", '
+        '"2b-v2", "4b", "7b", "9b", "12b", "27b", and "27b_v3".'
+    )
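
A brief usage sketch of the updated entry points follows. It assumes the file above is importable as gemma.config and that the six-entry attn_types tuple is meant to be cycled over the layers; both are illustrative assumptions, not something this change states.

# Usage sketch (illustrative only, not part of the change).
from gemma import config  # assumed import path for the file shown above

cfg = config.get_model_config('4b', dtype='bfloat16')

# The Gemma 3 variants populate the fields introduced by this change.
assert cfg.architecture == config.Architecture.GEMMA_3
assert cfg.use_qk_norm and cfg.rope_scaling_factor == 8
assert cfg.rope_wave_length[config.AttentionType.GLOBAL] == 1_000_000

# Expanding the repeating 5-local / 1-global attn_types pattern over all layers
# (hypothetical helper; how the model code consumes attn_types is not shown here).
per_layer = [
    cfg.attn_types[i % len(cfg.attn_types)]
    for i in range(cfg.num_hidden_layers)
]
print(per_layer.count(config.AttentionType.GLOBAL), 'global attention layers')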