added example for bidirectional checkpoint testing #1540


Open · wants to merge 2 commits into main

Conversation

wesleytruong (Contributor):

This PR adds:

  • an example script for bidirectional testing of checkpoint conversion scripts
  • a checkpoint_conversion.md to describe our methodology

@meta-cla bot added the CLA Signed label (this label is managed by the Meta Open Source bot) on Aug 6, 2025
logger.info(f"Loading chkpt at: {checkpoint_path}")
load_from_hf = False
for filename in os.listdir(checkpoint_path):
if filename == "model.safetensors.index.json":
Contributor:
This is not reliable -- if there is only one .safetensors file, there won't be such an index file.
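
A hedged sketch of a more robust check along these lines (the helper name below is made up for illustration):

```python
import os

def looks_like_hf_checkpoint(checkpoint_path: str) -> bool:
    # Sharded HF checkpoints ship an index file, but a single-shard checkpoint
    # only has model.safetensors, so check for any .safetensors file instead.
    return any(
        name == "model.safetensors.index.json" or name.endswith(".safetensors")
        for name in os.listdir(checkpoint_path)
    )

load_from_hf = looks_like_hf_checkpoint(checkpoint_path)
```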

Contributor:
rename it to README.md (so it's displayed when entering this folder)

### Sanity Check (Greedy Decode)
A quick way to sanity-check whether your conversion is correct is to run greedy-decoding inference on both the initial and converted checkpoints and confirm that the outputs are the same. This method doesn't guarantee correctness, but it will very likely produce a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.

Note that the model definitions can be influenced by external factors other than the correctness of the weight conversion. For example, using our verified `convert_to_hf.py` script and then running greedy decoding with HF `transformers` without a correct `config.json` will result in a **false negative**: our weights are correct, but the model definition is wrong because of the `config.json`.
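
For illustration, a minimal sketch of this greedy-decode comparison on the HF side, assuming both the original and the round-tripped (HF -> DCP -> HF) checkpoints can be loaded with `transformers`; the paths and prompt below are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

original_path = "meta-llama/Meta-Llama-3-8B"        # placeholder
roundtrip_path = "outputs/checkpoint/step-0-tohf"   # placeholder

tokenizer = AutoTokenizer.from_pretrained(original_path)
prompt = tokenizer("The capital of France is", return_tensors="pt")

completions = []
for path in (original_path, roundtrip_path):
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
    model.eval()
    with torch.no_grad():
        # Greedy decoding: do_sample=False takes the argmax token at every step.
        out = model.generate(**prompt, max_new_tokens=32, do_sample=False)
    completions.append(tokenizer.decode(out[0], skip_special_tokens=True))

# Identical outputs suggest (but do not prove) a correct conversion.
print("outputs match:", completions[0] == completions[1])
```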
Contributor:
This is more obscure than it needs to be. No need to say what could go wrong; say what they need to do to get it right.
E.g. you can just say that in order to use the HF transformers model, one needs to feed a correct config.json. Remove the "false negative" part.

## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity-check whether your conversion is correct is to run greedy-decoding inference on both the initial and converted checkpoints and confirm that the outputs are the same. This method doesn't guarantee correctness, but it will very likely produce a fast **true negative** if the model definitions are not the same. For greedy decoding, the `generation/test_generate.py` script can be used.
Contributor:
This test_generate is for llama 3 only. In general we don't have such a thing, and shouldn't rely on it anyway. So if you'd still like to have it here, let's explicitly say something like "it's only available for llama 3, but the methodology is general".

### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py`, this means running forward passes on DCP checkpoints loaded in `torchtitan` and on safetensors checkpoints loaded in HuggingFace `AutoModelForCausalLM`. We additionally compare against a conversion done with no permutation, to double-check that our permutation results in a lower KL divergence loss.
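
For illustration, a minimal sketch of such a logit comparison, assuming the logits of both models have already been computed on the same input batch; the names below are placeholders, not the example script's API:

```python
import torch
import torch.nn.functional as F

def kl_between_logits(logits_ref: torch.Tensor, logits_test: torch.Tensor) -> float:
    # Flatten (batch, seq, vocab) -> (batch*seq, vocab) so "batchmean"
    # averages the KL divergence over every token position.
    vocab = logits_ref.size(-1)
    log_p_test = F.log_softmax(logits_test.reshape(-1, vocab).float(), dim=-1)
    p_ref = F.softmax(logits_ref.reshape(-1, vocab).float(), dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(log_p_test, p_ref, reduction="batchmean").item()

# kl = kl_between_logits(titan_logits, hf_logits)
# A value at or near zero indicates the two model definitions agree.
```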
Contributor:
You need to provide some context on why this permutation is needed in the first place. Otherwise people will be confused about why you mention it at all.

@tianyu-l (Contributor) left a comment:
Looks quite good. Had some final comments.

state_dict.pop(k, None)

# Checkpoint Loading
logger.info(f"Loading chkpt at: {checkpoint_path}")
Contributor:
Suggested change
logger.info(f"Loading chkpt at: {checkpoint_path}")
logger.info(f"Loading checkpoint at: {checkpoint_path}")

Comment on lines +111 to +118
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
Contributor:
Could you add comments on what these checkpoints are and how they are generated / downloaded?
The point is -- anyone working on a new model should know what to do.
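
For illustration, one hedged way these paths might be annotated; the notes in the comments below are assumptions about how each checkpoint was produced, not the repository's verified instructions:

```python
# HF-format checkpoints written by the torchtitan -> HF conversion
# (the "*noperm" variant is assumed to skip the RoPE permutation, for comparison):
hf_model_path = "outputs/checkpoint/step-0-tohf"
hf_model_path_no_perm = "outputs/checkpoint/step-0-tohfnoperm"

# torchtitan params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
# DCP checkpoint assumed to be converted from the original Meta Llama weights:
baseline_checkpoint_path = "outputs/checkpoint/step-0-fromllama"
# DCP checkpoints converted back from the HF checkpoints above:
checkpoint_path = "outputs/checkpoint/step-0-fromhf"
checkpoint_path_no_perm = "outputs/checkpoint/step-0-fromhfnoperm"
```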

Contributor:
now reads very well!

### Comprehensive Check (KL Divergence)
To ensure comprehensive end-to-end correctness we recommend using KL divergence loss to compare the logits between forward passes of both the original and converted model definitions. KL divergence quantifies the "difference" between two probability distributions. A result of zero or a very low KL divergence indicates that the model definitions are equivalent. This method is crucial as it evaluates the entire probability distribution, not just the highest probability at each step.

In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
Contributor:
Suggested change
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a kl divergence test can reveal subtle inaccuracies such as this, we additionally compare the kl divergence between the original and converted model with and without the permutation. The results are as follows:
In our `./scripts/checkpoint_conversion/example.py` this will be performing forward on dcp checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in huggingface `AutoModelForCausalLM`. To convert Llama3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
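
For context, a sketch of the kind of permutation involved. It mirrors the layout transform used in HF's Llama weight-conversion script (interleaved RoPE pairs vs. the half-split layout); torchtitan's actual helper may differ:

```python
import torch

def permute_for_hf_rope(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Reorder rows so interleaved (even/odd) RoPE pairs, as in the native Llama
    # implementation, become the first-half/second-half layout HF's RoPE expects.
    return (
        w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

# Applied to the q/k projection weights, e.g. wq with n_heads and wk with n_kv_heads.
```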

Contributor:
let's name it numerical_tests_example.py
