added example for bidirectional checkpoint testing #1540

Open · wants to merge 5 commits into main
24 changes: 24 additions & 0 deletions scripts/checkpoint_conversion/README.md
Contributor: now reads very well!

@@ -0,0 +1,24 @@
# Testing Checkpoint Conversion for Correctness

When converting checkpoints between file formats or model definitions, we need to ensure that the converted checkpoints are correct, i.e. that the model definition is preserved and that the converted weights produce the same outputs once loaded in the intended target framework.

This guide provides a general framework for testing a conversion script for correctness. The running example is bidirectional conversion between HuggingFace and `torchtitan`.

## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity-check a conversion is to run greedy-decoding inference on both the initial and converted checkpoints and confirm that the generations are identical. This method doesn't guarantee correctness, but it will very likely produce a fast **true negative** if the model definitions are not the same. For Llama 3, greedy decoding can be done with the `generation/test_generate.py` script. Other models may not have an inference script, but the methodology is the same.
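
For illustration, the HuggingFace side of such a comparison can be as simple as greedily decoding the same prompt from the original and converted checkpoints and checking that the texts match. This is only a sketch: the paths, prompt, and generation length are placeholders, and the `torchtitan` side would go through `generation/test_generate.py` instead.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

original_path = "meta-llama/Meta-Llama-3-8B"        # original HF checkpoint
converted_path = "outputs/converted_hf_checkpoint"  # hypothetical tt -> HF conversion output

tokenizer = AutoTokenizer.from_pretrained(original_path)
prompt = tokenizer("The capital of France is", return_tensors="pt")

decoded = []
for path in (original_path, converted_path):
    model = AutoModelForCausalLM.from_pretrained(path).eval()
    with torch.no_grad():
        # do_sample=False gives deterministic greedy decoding
        out = model.generate(**prompt, max_new_tokens=32, do_sample=False)
    decoded.append(tokenizer.decode(out[0], skip_special_tokens=True))
    del model

print("outputs match" if decoded[0] == decoded[1] else "OUTPUTS DIFFER")
```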

Note that your model definition needs to match your conversion script. For example, if converting from `torchtitan` to HuggingFace, be sure to include the correct `config.json` file that matches the `torchtitan` model architecture. Providing an incorrect `config.json` when loading the model with HuggingFace `transformers` will result in incorrect generations despite a correct weight conversion.
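
A quick way to catch this kind of mismatch is to load the converted checkpoint's `config.json` and compare its architecture fields against the `torchtitan` flavor you converted from. A minimal sketch, assuming the converted checkpoint lives at a placeholder path:

```python
from transformers import AutoConfig

# placeholder path to the converted HF-format checkpoint directory
cfg = AutoConfig.from_pretrained("outputs/converted_hf_checkpoint")

# These should line up with the torchtitan model flavor,
# e.g. Llama 3 8B: 32 layers, 32 heads, 8 KV heads, hidden size 4096.
print(cfg.num_hidden_layers, cfg.num_attention_heads,
      cfg.num_key_value_heads, cfg.hidden_size)
```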

### Comprehensive Check (KL Divergence)
Our `./scripts/checkpoint_conversion/numerical_tests_example.py` runs a forward pass on DCP checkpoints loaded in `torchtitan` and on safetensors checkpoints loaded into HuggingFace `AutoModelForCausalLM`, then compares the outputs. The script tests the HuggingFace -> `torchtitan` direction, as loading a HuggingFace checkpoint requires both
- converting the instantiated `torchtitan` state dict `to_hf` so that the safetensors weights can be loaded into it, and
- converting the HF version of the state dict back to `torchtitan` using `from_hf`.
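
At its core, the test compares the last-token logits produced by the two frameworks using a KL divergence over the vocabulary. A minimal sketch of that comparison, mirroring `loss_fn` in the script below (the function name here is illustrative):

```python
import torch
import torch.nn.functional as F

def kl_between_logits(converted_logits: torch.Tensor, baseline_logits: torch.Tensor) -> torch.Tensor:
    # KL divergence between the next-token distributions of the two models;
    # values near zero indicate the conversion preserved the model definition
    log_probs = F.log_softmax(converted_logits, dim=-1)
    target_probs = F.softmax(baseline_logits, dim=-1)
    return F.kl_div(log_probs, target_probs, reduction="mean")
```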

To convert Llama 3 between HuggingFace and `torchtitan`, we had to permute several of the attention matrices to account for the difference between the HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies like this, we additionally compare the KL divergence between the original and converted models with and without the permutation. The results are as follows:
```
$ python ./scripts/checkpoint_conversion/numerical_tests_example.py
Average loss for test from_hf is -1.45365707318601e-13
Average loss for test from_hf_no_perm is 5.368335223465692e-06
```
162 changes: 162 additions & 0 deletions scripts/checkpoint_conversion/numerical_tests_example.py
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
from torchtitan.components.checkpoint import excluded_parameters_for_model_only
from torchtitan.config import ConfigManager
from torchtitan.protocols.train_spec import get_train_spec
from torchtitan.tools.logging import logger
from transformers import AutoModelForCausalLM

device_type = "cuda" if torch.cuda.is_available() else "cpu"


def loss_fn(logits1, logits2):
    # Convert logits to (log-)probabilities
    probs1 = F.log_softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)

    # Calculate the KL divergence between the two next-token distributions
    kl_loss = F.kl_div(probs1, probs2, reduction="mean")
return kl_loss


@torch.no_grad
def forward_hf(model_name, model_path: Optional[str], input_ids):
    # Load the HF model (by hub name or from a local path)
model_path = model_path if model_path else model_name
model = AutoModelForCausalLM.from_pretrained(model_path)

device = torch.device(device_type)
model.to(device)

# List to store outputs
outputs_list = []

for inputs in input_ids:
inputs = inputs.to(device)
        # Greedy-decode exactly one new token so that outputs.logits holds the
        # next-token logits for the prompt (prompt_len is set in __main__)
        outputs = model.generate(
            inputs=inputs,
            max_length=prompt_len + 1,
            do_sample=False,
            output_logits=True,
            return_dict_in_generate=True,
        )

outputs = torch.stack(outputs.logits)
outputs_list.append(outputs)

del model
torch.cuda.empty_cache()

return outputs_list


@torch.no_grad
def forward_tt(config_path, checkpoint_path, test_set):

config_manager = ConfigManager()
config = config_manager.parse_args([f"--job.config_file={config_path}"])

train_spec = get_train_spec(config.model.name)

model_args = train_spec.model_args[config.model.flavor]
model_args.update_from_config(config)

model = train_spec.model_cls(model_args)

    # materialize the model on the target device
device = torch.device(device_type)
model.to_empty(device=device)
with torch.no_grad():
model.init_weights()
model.eval()

state_dict = model.state_dict()
for k in excluded_parameters_for_model_only:
state_dict.pop(k, None)

    # Checkpoint loading: dcp.load updates the plain state_dict in place,
    # so no explicit load_state_dict call is needed here
logger.info(f"Loading checkpoint at: {checkpoint_path}")
dcp.load(state_dict, checkpoint_id=checkpoint_path)
Contributor: wait I thought you'd need
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L437
maybe in this case it's not needed because the pointer to params didn't change?
cc @fegin

Contributor: This model is not wrapped with ModelWrapper and the logic doesn't call distributed state_dict either, so technically, load_state_dict() is not required. dcp.load() will perform the in-place update.

Contributor: @wesleytruong wait I'm more confused. In convert_from_hf.py you used ModelWrapper for loading HF and saving. Here you call state_dict = model.state_dict() without ModelWrapper -- why can a checkpoint from the wrapped state dict be loaded into the non-wrapped state dict? Are they interchangeable??

Contributor Author (@wesleytruong, Aug 14, 2025): Sorry for the confusion; as Chien-Chin said, they end up with the same result since the model is not distributed. From what I understand, under the hood ModelWrapper calls PTD's get_model_state_dict, which handles getting the state dict of a sharded model, and dcp.load handles loading a state dict from a sharded checkpoint. DCP can go from an unsharded or sharded checkpoint to an unsharded state dict, and both model.state_dict() and ModelWrapper's get_model_state_dict are full state dicts if the model is unsharded, which is why this works. Either way, I should change both to use either ModelWrapper or the bare model for consistency. Which one would you prefer? @tianyu-l

output_list = []
for prompt in test_set:
input_ids = prompt.to(device_type)
# ensure batch dimension (T,) --> (B, T)
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)

        # keep only the logits of the last token: (B, T, V) -> (B, 1, V)
predictions = model(input_ids)[:, -1, :].unsqueeze(1)
output_list.append(predictions)

del model
torch.cuda.empty_cache()

return output_list


if __name__ == "__main__":
# hf params
hf_model_name = "meta-llama/Meta-Llama-3-8B"

# tt params
config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
checkpoint_path = "outputs/test_checkpoint/step-0-fromhf" # dcp checkpoint from convert_from_hf.py
# dcp checkpoint from convert_from_hf.py without using sd_adapter's permute
checkpoint_path_no_perm = "outputs/test_checkpoint/step-0-fromhfnoperm"

# test params
prompt_len = 8
test_size = 100

config_manager = ConfigManager()
config = config_manager.parse_args([f"--job.config_file={config_path}"])
train_spec = get_train_spec(config.model.name)
tokenizer = train_spec.build_tokenizer_fn(config)

# Build test set of randomly generated token ids
test_set = [
torch.randint(
0,
tokenizer.get_vocab_size(),
(
1, # batch size
prompt_len,
),
)
for _ in range(test_size)
]

# baseline logits
baseline_hf_outputs = forward_hf(hf_model_name, None, test_set)

# testing from hf conversion
from_hf_outputs = forward_tt(config_path, checkpoint_path, test_set)
from_hf_outputs_no_perm = forward_tt(config_path, checkpoint_path_no_perm, test_set)

# Define the set of outputs to test loss for
test_configs = {
"from_hf": [baseline_hf_outputs, from_hf_outputs],
"from_hf_no_perm": [baseline_hf_outputs, from_hf_outputs_no_perm],
}
avg_losses = {}

for test_name, (baseline_outputs, conversion_outputs) in test_configs.items():
total_loss = 0
for baseline, outputs in zip(baseline_outputs, conversion_outputs):
total_loss += loss_fn(baseline, outputs)
avg_loss = total_loss / len(test_set)
avg_losses[test_name] = avg_loss.item()

for test_name, avg_loss in avg_losses.items():
print(f"Average loss for test {test_name} is {avg_loss}")