
Unable to reproduce QAT results from Blog #2310

Open
AbhinavDutta opened this issue Jan 29, 2025 · 9 comments

@AbhinavDutta

I've been trying to reproduce the Llama3-8B QAT numbers from the blog but have been unable to do so. The training curve looks pretty bad (indicating no training is happening at all) and the evaluations are off as well (wandb logs). Can you let me know if I'm missing something in the configs?

output_dir: /mnt/azure_data/training/ckpts/ # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /mnt/azure_data/huggingface/llama3/Meta-Llama-3-8B/original/tokenizer.model
  max_seq_len: 8192

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  packed: False  # True increases speed
  source: json
  data_files: [/mnt/azure_data/datasets/c4/en/c4-train.00000-of-01024.json]
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir:  /mnt/azure_data/huggingface/llama3/Meta-Llama-3-8B/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
epochs: 1
fake_quant_after_n_steps: 1000
clip_grad_norm: inf
# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 5000
gradient_accumulation_steps: 1  # Use to increase effective batch size
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
#custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Reduced precision
dtype: bf16
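
For reference, a minimal sketch of the prepare/convert flow that the `quantizer` entry above drives, using the same component named in the config; the tiny `Sequential` model is only a stand-in for Llama3-8B, and the comments reflect my understanding rather than the recipe's exact internals:

```python
# Minimal sketch of the QAT prepare/convert flow behind the `quantizer` config entry.
# The small Linear is a stand-in for Llama3-8B; in_features must be divisible by groupsize.
import torch
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)   # swap Linear layers for fake-quantized versions
# ... fine-tune as usual: fake quant simulates int8 activations / int4 weights ...
model = quantizer.convert(model)   # replace fake quant with actual quantized weights
```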

These are the evals I get:

| Model | ARC-easy (acc_norm) | ARC-challenge (acc_norm) | MMLU |
| --- | --- | --- | --- |
| Baseline | 77.74 | 53.24 | 65.24 |
| W4A8 QAT | 74.74 | 49.40 | 54.92 |
| PTQ | 75.37 | 49.40 | 56.08 |
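
For what it's worth, a hedged sketch of how such evals could be run with EleutherAI's lm-evaluation-harness (the checkpoint path is a placeholder; torchtune's own eleuther_eval recipe wraps the same harness):

```python
# Hedged sketch: evaluate a fine-tuned checkpoint on the same tasks via lm-eval.
# The pretrained path is a placeholder for a HF-format checkpoint directory.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/finetuned/checkpoint,dtype=bfloat16",
    tasks=["arc_easy", "arc_challenge", "mmlu"],
)
print(results["results"])
```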
@AbhinavDutta
Author

Hi,
Any updates on this? As things stand, I'm inclined to believe that what I'm observing is not some "buggy" behaviour, but rather that QAT performs just as badly as other people have observed in their own experiments.

@ebsmothers
Contributor

Hi @AbhinavDutta, thanks for opening the issue. Sorry, somehow this one slipped through the cracks. Let me take a look now.

@ebsmothers
Contributor

@AbhinavDutta I haven't forgotten about this... still debugging. I confirmed that the loss curves for Llama3 8B QAT have not changed since the recipe first landed. I am also going to roll back torchao versions and run with Llama2 to figure out if/when there was a regression. Will keep you posted.

@AbhinavDutta
Author

AbhinavDutta commented Feb 1, 2025

@ebsmothers much appreciated!
Also, can you tell me which commit of torchtune and which version of torchao I need to use to get the "correct" results (as reported in the blog)?

@AbhinavDutta
Author

AbhinavDutta commented Feb 3, 2025

@ebsmothers I faced the same issue using torchtune 0.2.1, torch 2.4.1, and torchao 0.3.1. Looking at the dates, that combination seems closest to when the blog was published, but I'm seeing the same problems there as well.

However, when using the alpaca dataset, I observed that the loss curve does look normal (could that be because Llama 3 8B was not instruction tuned to begin with? UPDATE: it's not; I observe the same loss curve even when using Llama 3 8B-Instruct). I only tried alpaca because the QAT config files use it by default. Looking at the blog details, though, it appears C4 was used for QAT, and I'm guessing the following dataset config is what's appropriate in that case:

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  packed: False  # True increases speed
  source: json
  data_files: [/mnt/azure_data/datasets/c4/en/c4-train.00000-of-01024.json]

So, can you confirm whether you ever observed decent loss curves when training with torchtune.datasets.text_completion_dataset? If so, can you tell me the steps to reproduce?
Alternatively, can you share the YAML file that was used to get the numbers reported in the blog?
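
As a hedged sanity check (outside the recipe), the same text-completion dataset can be built directly to confirm the C4 shard is being read and tokenized; paths are placeholders, and the `column` argument assumes the JSON shard stores documents under a "text" key:

```python
# Hedged sanity check: build the C4 text-completion dataset outside the recipe
# and inspect one tokenized sample. Paths are placeholders for local copies.
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.datasets import text_completion_dataset

tokenizer = llama3_tokenizer(path="/path/to/Meta-Llama-3-8B/original/tokenizer.model")
ds = text_completion_dataset(
    tokenizer=tokenizer,
    source="json",
    data_files=["/path/to/c4/en/c4-train.00000-of-01024.json"],
    column="text",   # assumes the C4 JSON shard stores documents under "text"
    packed=False,
)
print(len(ds), ds[0]["tokens"][:16])
```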

@ebsmothers
Contributor

@AbhinavDutta we are still looking into this. I have run QAT on Llama2 and Llama3 all the way back to when the QAT PR first landed and I see the same loss curves (i.e. there has not been a regression since the original PR landed). I've confirmed with @andrewor14 that the original results are based on the instruct model, not the base model, and your config looks correct to me. Regarding the loss curves, I will let Andrew weigh in, as he conducted most of the experiments and so will have the most useful information here.

@AbhinavDutta
Author

@ebsmothers just to make sure I understand what's going on, could you clarify the following regarding the loss curves you're talking about (the ones that did not regress): (i) do they decrease and then flatten as usual? (ii) was the base model used or the instruct model? (iii) was C4 used or alpaca?

@andrewor14
Contributor

Hi @AbhinavDutta, by the way, have you tried just running the exact same workload without QAT? A few months ago I found that finetuning Llama3-8B on C4 (even without QAT) doesn't seem to converge anymore: #1526.

To answer your questions:
(i) We expect the loss curve to decrease and flatten, but this doesn't seem to happen for this workload even without QAT
(ii) We used the instruct model in all our Llama3 experiments
(iii) We used C4, but also saw good results with alpaca

@AbhinavDutta
Author

@andrewor14 thanks for letting me know! I'll see if the instruct version works out.

I'm guessing the training won't go anywhere without QAT either (the config I used did not apply any fake quantization for the first 1000 steps, and nothing good happened in that regime).
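
For context, a rough sketch of what `fake_quant_after_n_steps` appears to do in the recipe, assuming torchao's QAT prototype helpers (exact import paths differ across torchao versions, and the training loop body is elided):

```python
# Rough illustration of fake_quant_after_n_steps: fake quantization stays disabled
# for the first N steps, then gets switched on for the rest of training.
# Import paths follow torchao's QAT prototype and vary across torchao versions.
import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

fake_quant_after_n_steps = 1000
model = Int8DynActInt4WeightQATQuantizer(groupsize=256).prepare(
    torch.nn.Sequential(torch.nn.Linear(256, 256))  # stand-in for the real model
)

model.apply(disable_8da4w_fake_quant)            # first N steps: plain fine-tuning
for step in range(2 * fake_quant_after_n_steps):
    if step == fake_quant_after_n_steps:
        model.apply(enable_8da4w_fake_quant)     # then start simulating quantization
    # ... forward / backward / optimizer.step() as usual ...
```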
