Skip to content

Commit a3e825f

Browse files
authored
Merge pull request #138 from malfet/malfet/add-support-for-tinystories
By just defining configs and using model weights stored in the checkpoint. Should enable TinyStories LLMs posted in https://huggingface.co/karpathy/tinyllamas Test Plan: `python generate.py --checkpoint_path checkpoints/stories15M/stories15M.pt --prompt "Once upon a time"`
2 parents f479b07 + 11ce176 commit a3e825f

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

generate.py

+2
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ def _load_model(checkpoint_path, device, precision, use_tp):
235235
model = simple_quantizer.convert_for_runtime(use_cuda)
236236

237237
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
238+
if "model" in checkpoint and "stories" in str(checkpoint_path):
239+
checkpoint = checkpoint["model"]
238240
model.load_state_dict(checkpoint, assign=True)
239241

240242
if use_tp:

model.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def from_name(cls, name: str):
6363
"34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
6464
"70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
6565
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
66+
"stories15M": dict(n_layer=6, n_head=6, dim=288),
67+
"stories110M": dict(n_layer=12, n_head=12, dim=768),
6668
}
6769

6870
class KVCache(nn.Module):

0 commit comments

Comments
 (0)