I've been trying to reproduce the Llama3-8b QAT numbers from the blog but have been unable to do so. The training curve looks quite bad (suggesting no training is happening at all) and the evaluations are off as well (wandb logs). Can you let me know if I'm missing something in the config?
```yaml
output_dir: /mnt/azure_data/training/ckpts/ # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /mnt/azure_data/huggingface/llama3/Meta-Llama-3-8B/original/tokenizer.model
  max_seq_len: 8192

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  packed: False # True increases speed
  source: json
  data_files: [/mnt/azure_data/datasets/c4/en/c4-train.00000-of-01024.json]
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /mnt/azure_data/huggingface/llama3/Meta-Llama-3-8B/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
epochs: 1
fake_quant_after_n_steps: 1000
clip_grad_norm: inf

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 5000
gradient_accumulation_steps: 1 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower memory, but lower speed.

# Reduced precision
dtype: bf16
```
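For reference, my understanding of what the configured `Int8DynActInt4WeightQATQuantizer` simulates on the weight side is group-wise symmetric int4 fake quantization with one scale per `groupsize` values (256 above). A minimal pure-Python sketch of that quantize-dequantize round trip, under those assumptions (the function name and the tiny weight list are illustrative, not from torchtune):

```python
def fake_quant_int4(weights, groupsize=256):
    """Simulate group-wise symmetric int4 quantization: quantize each
    group of `groupsize` values to [-8, 7] with a shared scale, then
    dequantize back to float (weights stay float during QAT training)."""
    out = []
    for start in range(0, len(weights), groupsize):
        group = weights[start:start + groupsize]
        # One scale per group, chosen so the max magnitude maps to 7
        scale = max(abs(w) for w in group) / 7 or 1.0
        for w in group:
            q = max(-8, min(7, round(w / scale)))  # clamp to int4 range
            out.append(q * scale)                  # dequantize
    return out

# Toy example with groupsize=4 so one group covers the whole list
print(fake_quant_int4([0.7, -0.35, 0.1, 0.02], groupsize=4))
```

The point of running this during training (after `fake_quant_after_n_steps`) is that the model learns weights which survive this rounding, so the post-training int4 conversion loses less accuracy than plain PTQ.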
These are the evals I get:
| Model | ARC-easy (acc_norm) | ARC-challenge (acc_norm) | MMLU |
|---|---|---|---|
| Baseline | 77.74 | 53.24 | 65.24 |
| W4A8 QAT | 74.74 | 49.40 | 54.92 |
| PTQ | 75.37 | 49.40 | 56.08 |