Skip to content

Commit ef055fc

Browse files
committed
Added gemma support
1 parent dd62281 commit ef055fc

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

model.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ def from_name(cls, name: str):
5050

5151

5252
transformer_configs = {
53+
"gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
5354
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
5455
"7B": dict(n_layer=32, n_head=32, dim=4096),
5556
"13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -109,6 +110,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
109110
mask = self.causal_mask[None, None, input_pos]
110111
freqs_cis = self.freqs_cis[input_pos]
111112
x = self.tok_embeddings(idx)
113+
x = (self.config.dim ** 0.5) * x
112114

113115
for i, layer in enumerate(self.layers):
114116
x = layer(x, input_pos, freqs_cis, mask)
@@ -195,7 +197,7 @@ def __init__(self, config: ModelArgs) -> None:
195197
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
196198

197199
def forward(self, x: Tensor) -> Tensor:
198-
return self.w2(F.silu(self.w1(x)) * self.w3(x))
200+
return self.w2(F.gelu(self.w1(x)) * self.w3(x))
199201

200202

201203
class RMSNorm(nn.Module):
@@ -209,7 +211,7 @@ def _norm(self, x):
209211

210212
def forward(self, x: Tensor) -> Tensor:
211213
output = self._norm(x.float()).type_as(x)
212-
return output * self.weight
214+
return output * (1 + self.weight)
213215

214216

215217
def precompute_freqs_cis(

scripts/convert_hf_checkpoint.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,10 @@ def convert_hf_checkpoint(
3030
config = ModelArgs.from_name(model_name)
3131
print(f"Model config {config.__dict__}")
3232

33+
from safetensors import safe_open
34+
3335
# Load the json file containing weight mapping
34-
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
36+
model_map_json = checkpoint_dir / "model.safetensors.index.json"
3537

3638
assert model_map_json.is_file()
3739

@@ -65,7 +67,8 @@ def permute(w, n_head):
6567

6668
merged_result = {}
6769
for file in sorted(bin_files):
68-
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
70+
state_dict = safe_open(str(file), framework="pt", device='cpu')
71+
state_dict = {k: state_dict.get_tensor(k) for k in state_dict.keys()}
6972
merged_result.update(state_dict)
7073
final_result = {}
7174
for key, value in merged_result.items():
@@ -92,6 +95,9 @@ def permute(w, n_head):
9295
del final_result[key]
9396
del final_result[key.replace("wq", "wk")]
9497
del final_result[key.replace("wq", "wv")]
98+
if "output.weight" not in final_result:
99+
final_result["output.weight"] = final_result["tok_embeddings.weight"]
100+
95101
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
96102
torch.save(final_result, checkpoint_dir / "model.pth")
97103

0 commit comments

Comments
 (0)