Commit bd7ebb1

Michael Gschwind authored and malfet committed
reorg repo
1 parent 3c4bbe4 commit bd7ebb1

8 files changed: +287 −0

Diff for: CODE_OF_CONDUCT.md

+76
@@ -0,0 +1,76 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <[email protected]>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

Diff for: CONTRIBUTING.md

+32
@@ -0,0 +1,32 @@
# Contributing to gpt-fast
We want to make contributing to this project as easy and transparent as
possible.


## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to `gpt-fast`, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

Diff for: README.md

+32
@@ -26,6 +26,38 @@ Please copy-paste and fork as you desire.

# Supported Models
The model definition (and much more!) is adapted from gpt-fast, so we support the same models.

## Installation
[Download PyTorch nightly](https://pytorch.org/get-started/locally/), then install sentencepiece and huggingface_hub:
```bash
pip install sentencepiece huggingface_hub
```

To download Llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through the steps to obtain access.
Then log in with `huggingface-cli login`.

## Downloading Weights
Models tested/supported:
```text
tinyllamas/stories{15,42,110}
openlm-research/open_llama_7b
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-70b-chat-hf
codellama/CodeLlama-7b-Python-hf
codellama/CodeLlama-34b-Python-hf
mistralai/Mistral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
```

For example, to convert Llama-2-7b-chat-hf:
```bash
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
./scripts/prepare.sh $MODEL_REPO
```

See [`gpt-fast` Supported Models](https://github.com/pytorch-labs/gpt-fast?tab=readme-ov-file#supported-models) for a full list.

# Installation
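
To sanity-check the conversion, the saved checkpoint can be loaded directly; a minimal sketch, assuming the default output path written by `scripts/convert_hf_checkpoint.py` below:

```python
import torch

# Path produced by `./scripts/prepare.sh meta-llama/Llama-2-7b-chat-hf`.
ckpt = "checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"

# The converter saves a plain state dict, so it loads without any model code.
state_dict = torch.load(ckpt, map_location="cpu", mmap=True)
print(f"{len(state_dict)} tensors")
print(sorted(state_dict)[:3])  # e.g. the fused layers.0.attention.wqkv.weight
```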

Diff for: requirements.txt

+2
@@ -0,0 +1,2 @@
torch
sentencepiece

Diff for: scripts/convert_hf_checkpoint.py

+108
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import re
import sys
from pathlib import Path
from typing import Optional

import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import ModelArgs


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
    model_name: Optional[str] = None,
) -> None:
    if model_name is None:
        model_name = checkpoint_dir.name

    config = ModelArgs.from_name(model_name)
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

    assert model_map_json.is_file()

    with open(model_map_json) as json_map:
        bin_index = json.load(json_map)

    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    def permute(w, n_head):
        dim = config.dim
        return (
            w.view(n_head, 2, config.head_dim // 2, dim)
            .transpose(1, 2)
            .reshape(config.head_dim * n_head, dim)
        )

    merged_result = {}
    for file in sorted(bin_files):
        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
        merged_result.update(state_dict)
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r'(\d+)', '{}', key)
            layer_num = re.search(r'\d+', key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]

        final_result[new_key] = value

    for key in tuple(final_result.keys()):
        if "wq" in key:
            q = final_result[key]
            k = final_result[key.replace("wq", "wk")]
            v = final_result[key.replace("wq", "wv")]
            q = permute(q, config.n_head)
            k = permute(k, config.n_local_heads)
            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
            del final_result[key]
            del final_result[key.replace("wq", "wk")]
            del final_result[key.replace("wq", "wv")]
    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
    torch.save(final_result, checkpoint_dir / "model.pth")

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
    parser.add_argument('--model_name', type=str, default=None)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint_dir,
        model_name=args.model_name,
    )
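
The least obvious step above is `permute`: Hugging Face checkpoints store each head's q/k rows in a half-split rotary layout, and this reshuffle reorders them into the interleaved pairing the model code expects before the fused `torch.cat`. A self-contained sketch of the same row reshuffle, with toy dimensions invented for illustration:

```python
import torch

# Toy sizes (hypothetical): 2 heads, head_dim 4, model dim 8.
n_head, head_dim, dim = 2, 4, 8

def permute(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # Same operation as in convert_hf_checkpoint: view each head's rows as
    # (2 halves x head_dim/2), swap those two axes, and flatten back.
    return (
        w.view(n_head, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(head_dim * n_head, dim)
    )

w = torch.arange(n_head * head_dim * dim, dtype=torch.float32).view(n_head * head_dim, dim)
# Within each head, rows [0, 1, 2, 3] come out as [0, 2, 1, 3].
print(torch.equal(permute(w, n_head)[1], w[2]))  # True
```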

Diff for: scripts/download.py

+30
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Optional

from requests.exceptions import HTTPError


def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
    from huggingface_hub import snapshot_download
    os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
    try:
        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
    except HTTPError as e:
        if e.response.status_code == 401:
            print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
        else:
            raise e

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.')
    parser.add_argument('--repo_id', type=str, default="meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
    parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')

    args = parser.parse_args()
    hf_download(args.repo_id, args.hf_token)
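
`hf_download` can also be driven from Python rather than the CLI; a minimal sketch, assuming it runs from the repo root with `scripts` importable (the repo id below is just one of the supported models, and a token is only needed for gated repos):

```python
from scripts.download import hf_download

# Equivalent to: python scripts/download.py --repo_id openlm-research/open_llama_7b
# Snapshots the repo into checkpoints/openlm-research/open_llama_7b.
hf_download("openlm-research/open_llama_7b", hf_token=None)
```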

Diff for: scripts/prepare.sh

+1
@@ -0,0 +1 @@
python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1

Diff for: scripts/test_flow.sh

+6
@@ -0,0 +1,6 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
rm -r checkpoints/$MODEL_REPO
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100
