added example for bidirectional checkpoint testing #1540
base: main
Conversation
e3e1be8 to f5f9f14
logger.info(f"Loading chkpt at: {checkpoint_path}") | ||
load_from_hf = False | ||
for filename in os.listdir(checkpoint_path): | ||
if filename == "model.safetensors.index.json": |
This is not reliable -- if there is only one .safetensors file, there won't be such an index file.
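One way to address this (not part of the PR; a minimal sketch with an assumed helper name) is to accept either the index file or any `.safetensors` shard:

```python
import os

def looks_like_hf_safetensors(checkpoint_path: str) -> bool:
    # A single-file HF checkpoint ships only "model.safetensors" with no index,
    # so check for either the index file or any .safetensors shard.
    names = os.listdir(checkpoint_path)
    return "model.safetensors.index.json" in names or any(
        name.endswith(".safetensors") for name in names
    )
```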
rename it to README.md (so it's displayed when entering this folder)
### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that the outputs are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.

Note that the model definitions can be influenced by external factors other than the correctness of the weight conversion. For example, using our verified `convert_to_hf.py` script and then running greedy decoding with HF `transformers` without a correct `config.json` will result in a **false negative**, since our weights are correct but the model definition is wrong due to `config.json`.
This is more obscure than it needs to be. Don't say what could go wrong; say what they need to do to get it right. E.g. you can just say that in order to use the HF `transformers` model, one needs to feed a correct `config.json`. Remove the "false negative" part.
## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.
This `test_generate` is for Llama 3 only. In general we don't have such a thing, and shouldn't rely on it anyway. So if you still would like to have it here, let's explicitly say something like "it's only available for Llama 3, but the methodology is general".
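For reference, a minimal sketch of what the greedy-decode comparison looks like on the HF side (paths and prompt are placeholders; the `torchtitan` side would use `generation/test_generate.py` for Llama 3):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def greedy_decode(model_path: str, prompt: str, max_new_tokens: int = 64) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # do_sample=False gives deterministic greedy decoding.
        out = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
    return tokenizer.decode(out[0], skip_special_tokens=True)

# The two decodes should match token-for-token if the conversion is correct.
prompt = "The capital of France is"
assert greedy_decode("path/to/original", prompt) == greedy_decode("path/to/converted", prompt)
```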
### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. We additionally compare the conversions done with no permutation to double check that our permutation results in a lower kl divergence loss.
You need to provide some context on why this permutation is needed in the first place. Otherwise people will get confused about why you mention it at all.
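For context, this is roughly the permutation in question: a hedged sketch modeled on the `permute` helper in HF's Llama conversion script, not the PR's actual code (the function name and argument layout here are assumptions). It reorders the query/key projection rows between the interleaved RoPE layout of the native Llama weights and the half-split layout the HF implementation expects:

```python
import torch

def permute_rope(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Reorder each head's rows from the interleaved (pairwise) RoPE layout to
    # the half-split layout used by the HF Llama implementation. Applied to the
    # q_proj / k_proj weights during conversion (and inverted on the way back).
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
```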
Looks quite good. Had some final comments.
state_dict.pop(k, None)

# Checkpoint Loading
logger.info(f"Loading chkpt at: {checkpoint_path}")
logger.info(f"Loading chkpt at: {checkpoint_path}") | |
logger.info(f"Loading checkpoint at: {checkpoint_path}") |
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
Could you add comments: what are these checkpoints and how are they generated / downloaded? The point is -- anyone working on a new model should know what to do.
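Not part of the diff, but a sketch of the kind of annotation being asked for; the descriptions are inferred from the path names and the surrounding discussion, so treat them as assumptions:

```python
# HF-format (safetensors) checkpoints produced by converting the DCP checkpoint
# to Hugging Face format, with and without the RoPE permutation applied.
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# torchtitan (DCP) checkpoints: a baseline saved from the original Llama weights,
# plus checkpoints converted back from the HF safetensors above, with and
# without the permutation.
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
```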
now reads very well!
### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
Suggested change:
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
Let's name it `numerical_tests_example.py`.
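For reference, a minimal sketch of the KL comparison the section above describes, assuming both models have produced logits of shape `[batch, seq, vocab]` on the same input batch:

```python
import torch
import torch.nn.functional as F

def mean_kl_from_logits(ref_logits: torch.Tensor, test_logits: torch.Tensor) -> torch.Tensor:
    # Mean KL(P_ref || P_test) per token position; should be ~0 if the two
    # model definitions are equivalent.
    vocab = ref_logits.size(-1)
    log_p_ref = F.log_softmax(ref_logits.float().view(-1, vocab), dim=-1)
    log_p_test = F.log_softmax(test_logits.float().view(-1, vocab), dim=-1)
    # F.kl_div(input=log Q, target=log P, log_target=True) computes KL(P || Q).
    return F.kl_div(log_p_test, log_p_ref, log_target=True, reduction="batchmean")
```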
This PR adds `checkpoint_conversion.md` to describe our methodology.