 
 import dataclasses
 import enum
+import os
 import torch
 from typing import Optional, Sequence
+from .siglip_vision import config as siglip_vision_config
 
 
 # Keep a mapping from dtype strings to the supported torch dtypes.
@@ -37,6 +39,7 @@ class AttentionType(enum.Enum):
 class Architecture(enum.Enum):
     GEMMA_1 = 1
     GEMMA_2 = 2
+    GEMMA_3 = 3
 
 
 @dataclasses.dataclass
@@ -66,7 +69,9 @@ class GemmaConfig:
     # Whether a quantized version of the model is used.
     quant: bool = False
     # The path to the model tokenizer.
-    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
+    tokenizer: Optional[str] = (
+        'tokenizer/tokenizer.model'
+    )
     # The types of attention used in the layers of the model.
     attn_types: Optional[Sequence[AttentionType]] = None
     # The size of the sliding window used for local attention.
@@ -82,28 +87,38 @@ class GemmaConfig:
     use_pre_ffw_norm: bool = False
     # Whether to use post mlp normalization.
     use_post_ffw_norm: bool = False
+    # The wave length of the rotary embedding.
+    rope_wave_length: dict[AttentionType, int] | None = None
+    # Whether to use QK normalization in the attention blocks.
+    use_qk_norm: bool = False
+    # Vision model config.
+    vision_config: siglip_vision_config.SiglipVisionModelConfig | None = None
+    # The factor by which the rope wave length is divided for global layers.
+    rope_scaling_factor: int | None = None
 
     def get_dtype(self) -> Optional[torch.dtype]:
         """Gets the torch dtype from the config dtype string."""
         return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
 
 
-def get_config_for_7b() -> GemmaConfig:
-    return GemmaConfig()
+def get_config_for_7b(dtype: str = 'bfloat16') -> GemmaConfig:
+    return GemmaConfig(dtype=dtype)
 
 
-def get_config_for_2b() -> GemmaConfig:
+def get_config_for_2b(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         num_hidden_layers=18,
         num_attention_heads=8,
         num_key_value_heads=1,
         hidden_size=2048,
-        intermediate_size=16384
+        intermediate_size=16384,
     )
 
 
-def get_config_for_2b_v2() -> GemmaConfig:
+def get_config_for_2b_v2(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         architecture=Architecture.GEMMA_2,
         num_hidden_layers=26,
         num_attention_heads=8,
@@ -120,8 +135,9 @@ def get_config_for_2b_v2() -> GemmaConfig:
     )
 
 
-def get_config_for_9b() -> GemmaConfig:
+def get_config_for_9b(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         architecture=Architecture.GEMMA_2,
         num_hidden_layers=42,
         num_attention_heads=16,
@@ -138,38 +154,187 @@ def get_config_for_9b() -> GemmaConfig:
     )
 
 
-def get_config_for_27b() -> GemmaConfig:
-    return GemmaConfig(
-        architecture=Architecture.GEMMA_2,
-        num_hidden_layers=46,
-        num_attention_heads=32,
-        num_key_value_heads=16,
-        hidden_size=4608,
-        intermediate_size=36864,
-        use_pre_ffw_norm=True,
-        use_post_ffw_norm=True,
-        final_logit_softcapping=30.0,
-        attn_logit_softcapping=50.0,
-        head_dim=128,
-        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
-        sliding_window_size=4096,
-        query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
-    )
+def get_config_for_27b(dtype: str = 'bfloat16') -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_2,
+        num_hidden_layers=46,
+        num_attention_heads=32,
+        num_key_value_heads=16,
+        hidden_size=4608,
+        intermediate_size=36864,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
+        head_dim=128,
+        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
+        sliding_window_size=4096,
+        query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
+    )
 
 
-def get_model_config(variant: str) -> GemmaConfig:
-    if variant == '7b':
-        return get_config_for_7b()
-    elif variant == '2b':
-        return get_config_for_2b()
-    elif variant == '2b-v2':
-        return get_config_for_2b_v2()
-    elif variant == '9b':
-        return get_config_for_9b()
-    elif variant == '27b':
-        return get_config_for_27b()
-    else:
-        raise ValueError(
-            f'Invalid variant {variant}. Supported variants are "2b"'
-            'and "7b" and "9b" and "27b".')
+def get_config_for_1b(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=26,
+        num_attention_heads=4,
+        num_key_value_heads=1,
+        hidden_size=1152,
+        intermediate_size=6912,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=256,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=512,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        max_position_embeddings=32_768,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=None,
+    )
 
+
+def get_config_for_4b(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=34,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        hidden_size=2560,
+        intermediate_size=10240,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=256,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=1024,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+        rope_scaling_factor=8,
+    )
+
+
+def get_config_for_12b(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=48,
+        num_attention_heads=16,
+        num_key_value_heads=8,
+        hidden_size=3840,
+        intermediate_size=3840 * 8 // 2,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=256,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=1024,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        max_position_embeddings=131_072,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+        rope_scaling_factor=8,
+    )
+
+
+def get_config_for_27b_v3(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=62,
+        num_attention_heads=32,
+        num_key_value_heads=16,
+        hidden_size=5376,
+        intermediate_size=5376 * 8 // 2,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=128,
+        query_pre_attn_scalar=5376 // 32,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=1024,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        max_position_embeddings=131_072,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+        rope_scaling_factor=8,
+    )
+
+
+def get_model_config(variant: str, dtype: str = 'bfloat16') -> GemmaConfig:
+    """Gets the GemmaConfig for the desired variant and dtype."""
+    # Gemma1 variants
+    if variant == '7b':
+        return get_config_for_7b(dtype)
+    elif variant == '2b':
+        return get_config_for_2b(dtype)
+    # Gemma2 variants
+    elif variant == '2b-v2':
+        return get_config_for_2b_v2(dtype)
+    elif variant == '9b':
+        return get_config_for_9b(dtype)
+    elif variant == '27b':
+        return get_config_for_27b(dtype)
+    # Gemma3 variants
+    elif variant == '1b':
+        return get_config_for_1b(dtype)
+    elif variant == '4b':
+        return get_config_for_4b(dtype)
+    elif variant == '12b':
+        return get_config_for_12b(dtype)
+    elif variant == '27b_v3':
+        return get_config_for_27b_v3(dtype)
+    # Invalid variants
+    else:
+        raise ValueError(
+            f'Invalid variant {variant}. Supported variants are "1b", "2b", '
+            '"2b-v2", "4b", "7b", "9b", "12b", "27b", and "27b_v3".'
+        )
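
For reference, a minimal usage sketch of the updated factory function. It assumes this file remains importable as gemma.config (the import alias gemma_config below is illustrative, not part of the diff); the variant strings and the 'bfloat16' default dtype come from the code above.

from gemma import config as gemma_config

# Gemma 3 variant: GEMMA_3 architecture, QK normalization, SigLIP vision config.
cfg = gemma_config.get_model_config('4b')  # dtype defaults to 'bfloat16'
print(cfg.architecture)   # Architecture.GEMMA_3
print(cfg.get_dtype())    # torch.bfloat16

# Gemma 2 variant with an explicit dtype string.
cfg_9b = gemma_config.get_model_config('9b', dtype='float32')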