
Commit cb7c015

Add Gemma3
1 parent 3294a89 commit cb7c015

21 files changed, +1862 -231 lines

gemma/config.py

+204 -39
@@ -16,8 +16,10 @@
 
 import dataclasses
 import enum
+import os
 import torch
 from typing import Optional, Sequence
+from .siglip_vision import config as siglip_vision_config
 
 
 # Keep a mapping from dtype strings to the supported torch dtypes.
@@ -37,6 +39,7 @@ class AttentionType(enum.Enum):
 class Architecture(enum.Enum):
     GEMMA_1 = 1
     GEMMA_2 = 2
+    GEMMA_3 = 3
 
 
 @dataclasses.dataclass
@@ -66,7 +69,9 @@ class GemmaConfig:
     # Whether a quantized version of the model is used.
     quant: bool = False
     # The path to the model tokenizer.
-    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
+    tokenizer: Optional[str] = (
+        'tokenizer/tokenizer.model'
+    )
     # The types of attention used in the layers of the model.
     attn_types: Optional[Sequence[AttentionType]] = None
     # The size of the sliding window used for local attention.
@@ -82,28 +87,38 @@ class GemmaConfig:
     use_pre_ffw_norm: bool = False
     # Whether to use post mlp normalization.
     use_post_ffw_norm: bool = False
+    # The wave length of the rotary embedding.
+    rope_wave_length: dict[AttentionType, int] | None = None
+    # Whether to use QK normalization in the attention blocks.
+    use_qk_norm: bool = False
+    # Vision model config.
+    vision_config: siglip_vision_config.SiglipVisionModelConfig | None = None
+    # The factor by which the rope wave length is divided for global layers.
+    rope_scaling_factor: int | None = None
 
     def get_dtype(self) -> Optional[torch.dtype]:
         """Gets the torch dtype from the config dtype string."""
         return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
 
 
-def get_config_for_7b() -> GemmaConfig:
-    return GemmaConfig()
+def get_config_for_7b(dtype: str = 'bfloat16') -> GemmaConfig:
+    return GemmaConfig(dtype=dtype)
 
 
-def get_config_for_2b() -> GemmaConfig:
+def get_config_for_2b(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         num_hidden_layers=18,
         num_attention_heads=8,
         num_key_value_heads=1,
         hidden_size=2048,
-        intermediate_size=16384
+        intermediate_size=16384,
     )
 
 
-def get_config_for_2b_v2() -> GemmaConfig:
+def get_config_for_2b_v2(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         architecture=Architecture.GEMMA_2,
         num_hidden_layers=26,
         num_attention_heads=8,
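
The hunks above thread an explicit dtype string through every config factory. A minimal usage sketch, assuming the repository's gemma package is importable and that 'float32' is one of the keys in the _STR_DTYPE_TO_TORCH_DTYPE mapping that get_dtype() consults:

import torch

from gemma import config as gemma_config

# dtype defaults to 'bfloat16'; pass another supported string to override it.
cfg_bf16 = gemma_config.get_config_for_2b()
cfg_fp32 = gemma_config.get_config_for_2b(dtype='float32')  # assumed key

# get_dtype() maps the config's dtype string to a torch dtype (None if unknown).
assert cfg_bf16.get_dtype() == torch.bfloat16
assert cfg_fp32.get_dtype() == torch.float32
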
@@ -120,8 +135,9 @@ def get_config_for_2b_v2() -> GemmaConfig:
     )
 
 
-def get_config_for_9b() -> GemmaConfig:
+def get_config_for_9b(dtype: str = 'bfloat16') -> GemmaConfig:
     return GemmaConfig(
+        dtype=dtype,
         architecture=Architecture.GEMMA_2,
         num_hidden_layers=42,
         num_attention_heads=16,
@@ -138,38 +154,187 @@ def get_config_for_9b() -> GemmaConfig:
     )
 
 
-def get_config_for_27b() -> GemmaConfig:
-    return GemmaConfig(
-        architecture=Architecture.GEMMA_2,
-        num_hidden_layers=46,
-        num_attention_heads=32,
-        num_key_value_heads=16,
-        hidden_size=4608,
-        intermediate_size=36864,
-        use_pre_ffw_norm=True,
-        use_post_ffw_norm=True,
-        final_logit_softcapping=30.0,
-        attn_logit_softcapping=50.0,
-        head_dim=128,
-        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
-        sliding_window_size=4096,
-        query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
-    )
+def get_config_for_27b(dtype: str = 'bfloat16') -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_2,
+        num_hidden_layers=46,
+        num_attention_heads=32,
+        num_key_value_heads=16,
+        hidden_size=4608,
+        intermediate_size=36864,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
+        head_dim=128,
+        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
+        sliding_window_size=4096,
+        query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
+    )
 
 
-def get_model_config(variant: str) -> GemmaConfig:
-    if variant == '7b':
-        return get_config_for_7b()
-    elif variant == '2b':
-        return get_config_for_2b()
-    elif variant == '2b-v2':
-        return get_config_for_2b_v2()
-    elif variant == '9b':
-        return get_config_for_9b()
-    elif variant == '27b':
-        return get_config_for_27b()
-    else:
-        raise ValueError(
-            f'Invalid variant {variant}. Supported variants are "2b"'
-            'and "7b" and "9b" and "27b".')
+def get_config_for_1b(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=26,
+        num_attention_heads=4,
+        num_key_value_heads=1,
+        hidden_size=1152,
+        intermediate_size=6912,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=256,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=512,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        max_position_embeddings=32_768,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=None,
+    )
 
+
+def get_config_for_4b(dtype: str) -> GemmaConfig:
+    return GemmaConfig(
+        dtype=dtype,
+        architecture=Architecture.GEMMA_3,
+        num_hidden_layers=34,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        hidden_size=2560,
+        intermediate_size=10240,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        head_dim=256,
+        attn_types=(
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.LOCAL_SLIDING,
+            AttentionType.GLOBAL,
+        ),
+        sliding_window_size=1024,
+        rope_wave_length={
+            AttentionType.LOCAL_SLIDING: 10_000,
+            AttentionType.GLOBAL: 1_000_000,
+        },
+        vocab_size=262_144,
+        tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
+        use_qk_norm=True,
+        vision_config=siglip_vision_config.get_siglip_vision_model_config(),
+        rope_scaling_factor=8,
+    )
+
+
def get_config_for_12b(dtype: str) -> GemmaConfig:
244+
return GemmaConfig(
245+
dtype=dtype,
246+
architecture=Architecture.GEMMA_3,
247+
num_hidden_layers=48,
248+
num_attention_heads=16,
249+
num_key_value_heads=8,
250+
hidden_size=3840,
251+
intermediate_size=3840 * 8 // 2,
252+
use_pre_ffw_norm=True,
253+
use_post_ffw_norm=True,
254+
head_dim=256,
255+
attn_types=(
256+
AttentionType.LOCAL_SLIDING,
257+
AttentionType.LOCAL_SLIDING,
258+
AttentionType.LOCAL_SLIDING,
259+
AttentionType.LOCAL_SLIDING,
260+
AttentionType.LOCAL_SLIDING,
261+
AttentionType.GLOBAL,
262+
),
263+
sliding_window_size=1024,
264+
rope_wave_length={
265+
AttentionType.LOCAL_SLIDING: 10_000,
266+
AttentionType.GLOBAL: 1_000_000,
267+
},
268+
vocab_size=262_144,
269+
max_position_embeddings=131_072,
270+
tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
271+
use_qk_norm=True,
272+
vision_config=siglip_vision_config.get_siglip_vision_model_config(),
273+
rope_scaling_factor=8,
274+
)
275+
276+
277+
def get_config_for_27b_v3(dtype: str) -> GemmaConfig:
278+
return GemmaConfig(
279+
dtype=dtype,
280+
architecture=Architecture.GEMMA_3,
281+
num_hidden_layers=62,
282+
num_attention_heads=32,
283+
num_key_value_heads=16,
284+
hidden_size=5376,
285+
intermediate_size=5376 * 8 // 2,
286+
use_pre_ffw_norm=True,
287+
use_post_ffw_norm=True,
288+
head_dim=128,
289+
query_pre_attn_scalar=5376 // 32,
290+
attn_types=(
291+
AttentionType.LOCAL_SLIDING,
292+
AttentionType.LOCAL_SLIDING,
293+
AttentionType.LOCAL_SLIDING,
294+
AttentionType.LOCAL_SLIDING,
295+
AttentionType.LOCAL_SLIDING,
296+
AttentionType.GLOBAL,
297+
),
298+
sliding_window_size=1024,
299+
rope_wave_length={
300+
AttentionType.LOCAL_SLIDING: 10_000,
301+
AttentionType.GLOBAL: 1_000_000,
302+
},
303+
vocab_size=262_144,
304+
max_position_embeddings=131_072,
305+
tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
306+
use_qk_norm=True,
307+
vision_config=siglip_vision_config.get_siglip_vision_model_config(),
308+
rope_scaling_factor=8,
309+
)
310+
311+
312+
def get_model_config(variant: str, dtype: str = 'bfloat16') -> GemmaConfig:
313+
"""Gets the GemmaConfig for the diresired variant and dtype."""
314+
# Gemma1 variants
315+
if variant == '7b':
316+
return get_config_for_7b(dtype)
317+
elif variant == '2b':
318+
return get_config_for_2b(dtype)
319+
# Gemma2 variants
320+
elif variant == '2b-v2':
321+
return get_config_for_2b_v2(dtype)
322+
elif variant == '9b':
323+
return get_config_for_9b(dtype)
324+
elif variant == '27b':
325+
return get_config_for_27b(dtype)
326+
# Gemma3 variants
327+
elif variant == '1b':
328+
return get_config_for_1b(dtype)
329+
elif variant == '4b':
330+
return get_config_for_4b(dtype)
331+
elif variant == '12b':
332+
return get_config_for_12b(dtype)
333+
elif variant == '27b_v3':
334+
return get_config_for_27b_v3(dtype)
335+
# Invalid variants
336+
else:
337+
raise ValueError(
338+
f'Invalid variant {variant}. Supported variants are "1b", "2b", '
339+
'"2b-v2", "4b",, "7b", "9b" "12b", "27b", and "27b_v3".'
340+
)
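
For reference, a short sketch of the extended get_model_config in use. The variant strings and behaviors follow directly from the function above; the import path simply assumes the repository's gemma package is on PYTHONPATH:

from gemma import config as gemma_config

# Gemma 3 variants added by this commit.
cfg_1b = gemma_config.get_model_config('1b', dtype='bfloat16')
cfg_4b = gemma_config.get_model_config('4b', dtype='bfloat16')

assert cfg_1b.architecture == gemma_config.Architecture.GEMMA_3
assert cfg_1b.vision_config is None       # 1b is text-only.
assert cfg_4b.vision_config is not None   # 4b carries the SigLIP vision config.

# Unsupported variant strings raise ValueError.
try:
    gemma_config.get_model_config('3b')
except ValueError as err:
    print(err)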

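The Gemma 3 configs describe their layer mix with a six-entry attn_types pattern (five local sliding-window layers per global layer) and a per-attention-type rope_wave_length. The sketch below is one plausible way a model implementation could walk those fields; the cyclic indexing and variable names are illustrative assumptions, not code from this commit:

from gemma import config as gemma_config

cfg = gemma_config.get_model_config('1b')
AttentionType = gemma_config.AttentionType

for layer_id in range(cfg.num_hidden_layers):
    # Assumed: the 6-entry pattern repeats cyclically over all layers.
    attn_type = cfg.attn_types[layer_id % len(cfg.attn_types)]
    # Assumed: each attention type takes its RoPE base wavelength from the config dict.
    rope_base = cfg.rope_wave_length[attn_type]
    window = cfg.sliding_window_size if attn_type == AttentionType.LOCAL_SLIDING else None
    print(layer_id, attn_type.name, rope_base, window)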