added example for bidirectional checkpoint testing #1540
base: main
Changes from all commits
d8fd427
2f70223
bd39e50
833785b
eeb649c
@@ -0,0 +1,24 @@
# Testing Checkpoint Conversion for Correctness

When converting checkpoints between file formats or model definitions, we need to ensure that the converted checkpoints are correct, i.e. that the model they define is unchanged: the converted weights must produce the same outputs once loaded in the new, intended program context.

This guide provides a general framework for testing your conversion script for correctness. The example used here is bidirectional conversion between HuggingFace and `torchtitan`.
## Methods

### Sanity Check (Greedy Decode)
A quick way to sanity check a conversion is to run greedy decoding on both the original and the converted checkpoint and confirm that the generations are identical. This method doesn't guarantee correctness, but it will very likely catch a mismatch quickly if the model definitions are not the same. For Llama 3, greedy decoding can be performed with the `generation/test_generate.py` script. Other models may not have an inference script, but the methodology is the same.

Note that your model definition needs to match your conversion script. For example, if converting from `torchtitan` to HuggingFace, be sure to include a `config.json` file that matches the `torchtitan` model architecture. Providing an incorrect `config.json` when loading the model with HuggingFace `transformers` will result in incorrect generations even if the weight conversion itself is correct.
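For instance, here is a minimal sketch of such a check for the case where the converted checkpoint is in HuggingFace format; the converted checkpoint path and the prompt are hypothetical, not part of this repository:

```python
# Sketch only: compare greedy continuations from the reference checkpoint and a
# converted one. `converted_dir` is a hypothetical output of your conversion script.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

reference_dir = "meta-llama/Meta-Llama-3-8B"
converted_dir = "outputs/llama3_8b_converted_to_hf"  # hypothetical path

tokenizer = AutoTokenizer.from_pretrained(reference_dir)
prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

generations = []
for path in (reference_dir, converted_dir):
    # Load both models in the same dtype so the comparison is apples-to-apples
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float32)
    with torch.no_grad():
        out = model.generate(prompt_ids, max_new_tokens=32, do_sample=False)
    generations.append(tokenizer.decode(out[0]))
    del model

assert generations[0] == generations[1], "Greedy generations differ; check the conversion."
```
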
### Comprehensive Check (KL Divergence)
A more thorough check is to compare the models' outputs directly. Our `./scripts/checkpoint_conversion/numerical_test_example.py` does this by running a forward pass on a DCP checkpoint loaded in `torchtitan` and on a safetensors checkpoint loaded in HuggingFace `AutoModelForCausalLM`, then computing the KL divergence between their outputs. The script tests the HuggingFace -> `torchtitan` direction, since loading a HuggingFace checkpoint requires both (a toy sketch of this round trip follows the list):
- converting the instantiated `torchtitan` state dict `to_hf` so that the safetensors weights can be loaded into it, and
- converting the HF version of the state dict back to `torchtitan` using `from_hf`.
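Here is a toy sketch of that round trip, with stand-in `to_hf`/`from_hf` functions that only rename a single key; this is not torchtitan's real adapter API, whose adapters also handle the full key mapping and any weight permutations:

```python
# Toy illustration only; key names and shapes are illustrative.
import torch

def to_hf(tt_sd):
    # torchtitan key -> HuggingFace key
    return {"model.embed_tokens.weight": tt_sd["tok_embeddings.weight"]}

def from_hf(hf_sd):
    # HuggingFace key -> torchtitan key
    return {"tok_embeddings.weight": hf_sd["model.embed_tokens.weight"]}

# 1) Start from the instantiated torchtitan model's state dict.
tt_state_dict = {"tok_embeddings.weight": torch.empty(8, 4)}

# 2) Convert it to the HF layout so the safetensors weights can be loaded into it.
hf_state_dict = to_hf(tt_state_dict)
hf_state_dict["model.embed_tokens.weight"] = torch.randn(8, 4)  # stand-in for safetensors loading

# 3) Convert the filled HF-keyed state dict back to the torchtitan layout.
tt_state_dict = from_hf(hf_state_dict)
```
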
To convert Llama 3 between HuggingFace and `torchtitan`, we had to permute several of the attention matrices to account for the difference between the HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows:
```
$ python ./scripts/checkpoint_conversion/numerical_test_example.py
Average loss for test from_hf is -1.45365707318601e-13
Average loss for test from_hf_no_perm is 5.368335223465692e-06
```
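Note that KL divergence is mathematically non-negative, so the slightly negative `from_hf` value is floating-point noise around zero, while the `from_hf_no_perm` value is a genuine (if small) mismatch. For reference, `F.kl_div` expects log-probabilities as its first argument and probabilities as its second and computes KL(target || input); a minimal toy illustration of that convention (the logits are made up, not taken from the script):

```python
# With identical logits the KL divergence is (numerically) zero; a conversion bug
# shows up as a clearly positive value.
import torch
import torch.nn.functional as F

logits_baseline = torch.tensor([[2.0, 0.5, -1.0, 0.0]])   # toy vocab of size 4
logits_converted = torch.tensor([[2.0, 0.5, -1.0, 0.0]])  # identical on purpose

log_q = F.log_softmax(logits_baseline, dim=-1)  # first argument: log-probabilities
p = F.softmax(logits_converted, dim=-1)         # second argument: probabilities
kl = F.kl_div(log_q, p, reduction="sum")        # sum over the vocab gives KL(p || q)
print(kl.item())  # ~0.0
```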
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
from torchtitan.components.checkpoint import excluded_parameters_for_model_only
from torchtitan.config import ConfigManager
from torchtitan.protocols.train_spec import get_train_spec
from torchtitan.tools.logging import logger
from transformers import AutoModelForCausalLM

device_type = "cuda" if torch.cuda.is_available() else "cpu"


def loss_fn(logits1, logits2):
    # Convert logits to (log-)probabilities; F.kl_div expects log-probabilities
    # as its first argument and probabilities as its second
    probs1 = F.log_softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)

    # Calculate KL divergence, averaged over all elements
    kl_loss = F.kl_div(probs1, probs2, reduction="mean")
    return kl_loss


@torch.no_grad()
def forward_hf(model_name, model_path: Optional[str], input_ids):
    # Load the model (from a local path if given, otherwise from the hub)
    model_path = model_path if model_path else model_name
    model = AutoModelForCausalLM.from_pretrained(model_path)

    device = torch.device(device_type)
    model.to(device)

    # List to store outputs
    outputs_list = []

    for inputs in input_ids:
        inputs = inputs.to(device)
        outputs = model.generate(
            inputs=inputs,
            # `prompt_len` is defined in the __main__ block below;
            # generate exactly one token beyond the prompt
            max_length=prompt_len + 1,
            do_sample=False,
            output_logits=True,
            return_dict_in_generate=True,
        )

        # stack the per-step logits: (num_new_tokens, batch, vocab)
        outputs = torch.stack(outputs.logits)
        outputs_list.append(outputs)

    del model
    torch.cuda.empty_cache()

    return outputs_list


@torch.no_grad()
def forward_tt(config_path, checkpoint_path, test_set):
    config_manager = ConfigManager()
    config = config_manager.parse_args([f"--job.config_file={config_path}"])

    train_spec = get_train_spec(config.model.name)

    model_args = train_spec.model_args[config.model.flavor]
    model_args.update_from_config(config)

    model = train_spec.model_cls(model_args)

    # materialize the model on the target device
    device = torch.device(device_type)
    model.to_empty(device=device)
    with torch.no_grad():
        model.init_weights()
    model.eval()

    state_dict = model.state_dict()
    for k in excluded_parameters_for_model_only:
        state_dict.pop(k, None)

    # Checkpoint loading
    logger.info(f"Loading checkpoint at: {checkpoint_path}")
    dcp.load(state_dict, checkpoint_id=checkpoint_path)

[Review thread on the line above]
- wait, I thought you'd need [...] maybe in this case it's not needed because the pointer to params didn't change?
- This model is not wrapped with ModelWrapper and the logic doesn't call distributed state_dict either, so technically [...]
- @wesleytruong wait, I'm more confused.
- Sorry for the confusion; as Chien-Chin said, they end up with the same result since it's not distributed. From what I understand, under the hood ModelWrapper calls PTD's [...]. Either way, I should change both to follow either ModelWrapper or model, for consistency. Which one would you prefer?

    output_list = []
    for prompt in test_set:
        input_ids = prompt.to(device_type)
        # ensure batch dimension (T,) --> (B, T)
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        # obtains the logits of only the last token in the predictions
        predictions = model(input_ids)[:, -1, :].unsqueeze(1)
        output_list.append(predictions)

    del model
    torch.cuda.empty_cache()

    return output_list


if __name__ == "__main__":
    # hf params
    hf_model_name = "meta-llama/Meta-Llama-3-8B"

    # tt params
    config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml"
    checkpoint_path = "outputs/test_checkpoint/step-0-fromhf"  # dcp checkpoint from convert_from_hf.py
    # dcp checkpoint from convert_from_hf.py without using sd_adapter's permute
    checkpoint_path_no_perm = "outputs/test_checkpoint/step-0-fromhfnoperm"

    # test params
    prompt_len = 8
    test_size = 100

    config_manager = ConfigManager()
    config = config_manager.parse_args([f"--job.config_file={config_path}"])
    train_spec = get_train_spec(config.model.name)
    tokenizer = train_spec.build_tokenizer_fn(config)

    # Build test set of randomly generated token ids
    test_set = [
        torch.randint(
            0,
            tokenizer.get_vocab_size(),
            (
                1,  # batch size
                prompt_len,
            ),
        )
        for _ in range(test_size)
    ]

    # baseline logits
    baseline_hf_outputs = forward_hf(hf_model_name, None, test_set)

    # testing from hf conversion
    from_hf_outputs = forward_tt(config_path, checkpoint_path, test_set)
    from_hf_outputs_no_perm = forward_tt(config_path, checkpoint_path_no_perm, test_set)

    # Define the set of outputs to test loss for
    test_configs = {
        "from_hf": [baseline_hf_outputs, from_hf_outputs],
        "from_hf_no_perm": [baseline_hf_outputs, from_hf_outputs_no_perm],
    }
    avg_losses = {}

    for test_name, (baseline_outputs, conversion_outputs) in test_configs.items():
        total_loss = 0
        for baseline, outputs in zip(baseline_outputs, conversion_outputs):
            total_loss += loss_fn(baseline, outputs)
        avg_loss = total_loss / len(test_set)
        avg_losses[test_name] = avg_loss.item()

    for test_name, avg_loss in avg_losses.items():
        print(f"Average loss for test {test_name} is {avg_loss}")
Review comment: now reads very well!