This repository contains the official code for the paper “Distilling to Hybrid Attention Models via KL-Guided Layer Selection.”
It also includes a reimplementation of RADLADS: Rapid Attention Distillation to Linear Attention Decoders at Scale (arXiv:2505.03005), which we use as the backbone of our distillation pipeline.
git clone git@github.com:fla-org/hybrid-distillation.git
cd hybrid-distillation
conda create -n your_env_name python=3.12
conda activate your_env_name
pip install -r requirements.txt
pip install flash-attn==2.7.4.post1 --no-build-isolation

Set your environment variables:
export HF_TOKEN="YOUR_HF_TOKEN"
export HF_HOME="YOUR_HF_HOME"
export WANDB_API_KEY="YOUR_WANDB_API_KEY"
export TRITON_CACHE_DIR=...
export WANDB_ENTITY="YOUR_WANDB_ENTITY"
export WANDB_PROJECT="YOUR_WANDB_PROJECT"
- Download and tokenize a Hugging Face dataset and save it to disk (save_to_disk).
python preprocess_download_tokenize.py \
--dataset_name <hf_dataset_name_or_path> \
--split train \
--text_field text \
--tokenizer fla-hub/Qwen2.5-3B-Instruct \
--output_dir data_cache/tokenized_dataset

- Chunk the concatenated token stream into fixed-length blocks for each stage.
# for stage1
python preprocess_chunk.py \
--tokenized_dataset_path data_cache/tokenized_dataset \
--context_length 512 \
--output_dir data_cache
# for stage2
python preprocess_chunk.py \
--tokenized_dataset_path data_cache/tokenized_dataset \
--context_length 4096 \
--output_dir data_cache

Training has two stages, driven by YAML configs under configs/. You typically run Stage 1 once per (teacher, student variant), then reuse its checkpoint for any number of Stage 2 runs (different ratios / selectors / layer sets).
Stage 1 aligns the student’s linear-attention outputs. It is always run with an all-linear student (keep_full_attention_layers: []).
Example (Qwen2.5-3B teacher, gdn_v4 student):
deepspeed train.py --cfg configs/qwen2_3b_gdn_v4_hybrid_0_125_uniform/stage1.yaml

After Stage 1 finishes, the HF-format checkpoint is saved to:
<train.output_dir>/converted-hf/
Default Stage-1 settings in the provided configs:
- Target tokens: 100M
- Sequence length: 512
- Peak LR: 1e-3
- Scheduler: cosine
- Batch size: 96
Stage 2 distills logits from the teacher into a hybrid student. This is where you specify the softmax layers via student_model.keep_full_attention_layers.
deepspeed train.py --cfg configs/qwen2_3b_gdn_v4_hybrid_0_125_uniform/stage2.yaml

Before launching Stage 2:
- Set train.student_init_ckpt to the Stage-1 checkpoint: <stage1_output_dir>/converted-hf/
- Set student_model.keep_full_attention_layers: [...] for your hybrid.
Recommended Stage-2 settings:
- Target tokens: 600M
- Sequence length: 4096
All configs live under configs/. A config folder usually contains a stage2.yaml for a specific hybrid run, and may also include a stage1.yaml for convenience (but you do not need to run Stage 1 for every folder).
For a fixed pair of teacher_model.name and student_model.name,
you run Stage 1 once and reuse the resulting:
<stage1_output_dir>/converted-hf/
Then any Stage-2 config for that same pair should point to it via:
train.student_init_ckpt: <stage1_output_dir>/converted-hf/
Config directories typically follow:
{teacher_family}_{size}_{student_variant}_hybrid_{ratio}_{selector}/
Examples:
- configs/qwen2_3b_gdn_v4_hybrid_0_125_uniform/
- configs/llama3_3b_gdn_v4_hybrid_0_25_ga_s2/
Key fields in each config:
- data.cache_dir
  - Stage 1: usually data_cache/chunked_context512
  - Stage 2: usually data_cache/chunked_context4096
- student_model.keep_full_attention_layers (Stage 2): the chosen softmax layers
- train.output_dir: output path for checkpoints/logs
- train.student_init_ckpt (Stage 2): the shared Stage-1 converted-hf/
- train.target_tokens, train.train_seq_len, batch sizes, etc.
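Before launching a Stage-2 run, it can be worth sanity-checking these fields. Below is a minimal, illustrative Python sketch (not part of the repo) that assumes PyYAML and the nested key layout implied by the dotted names above; adjust it to your actual config schema.

```python
# sanity_check_stage2_config.py (illustrative helper, not part of the repo).
# Checks that train.student_init_ckpt exists on disk and that
# student_model.keep_full_attention_layers is set before a Stage-2 launch.
import os
import sys

import yaml  # pip install pyyaml


def get_nested(cfg, dotted_key):
    """Resolve a dotted key like 'train.student_init_ckpt' in a nested dict."""
    node = cfg
    for part in dotted_key.split("."):
        node = node[part]
    return node


def main(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)

    ckpt = get_nested(cfg, "train.student_init_ckpt")
    layers = get_nested(cfg, "student_model.keep_full_attention_layers")

    assert os.path.isdir(ckpt), f"Stage-1 checkpoint not found: {ckpt}"
    assert isinstance(layers, list) and layers, "keep_full_attention_layers is empty"
    print(f"OK: init from {ckpt}, softmax layers = {layers}")


if __name__ == "__main__":
    main(sys.argv[1])
```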
- Choose (or create) a canonical Stage-1 run for your (teacher, student) pair.
- Run Stage 1 once and keep <output_dir>/converted-hf/.
- Create new Stage-2 folders for different ratios / selectors / layer sets.
- Point all of them to the same train.student_init_ckpt.
This part includes scripts to reproduce KL-guided layer selection (our GA-S2 selector) and several baseline layer-importance heuristics.
Our main selector measures each layer’s marginal utility by restoring exactly one layer to full (softmax) attention in an otherwise linear-attention student, running Stage 2 distillation, then ranking layers by how much the distillation objective improves (we use the logged training loss / KL proxy).
First run Stage 1 to align the student’s linear-attention outputs. This produces a Stage-1 checkpoint, which you then convert to a unified HF-style checkpoint (the converted-hf folder described above in the Stage 1 section).
You should end up with something like:
checkpoints/<exp_name>/stage1/converted-hf/
Use generate_layer_configs_ga_s2.py to generate N Stage-2 YAML configs (one per layer). Each config sets:
- keep_full_attention_layers: [idx] for that layer index
- output_dir uniquely per layer (important for W&B retrieval)
Example:
python layer_selection/generate_layer_configs_ga_s2.py \
--num-layers 28 \
--config-root config/qwen2_7b_gdn_v4_hybrid_layer_selection \
--teacher fla-hub/Qwen2.5-7B-Instruct \
--student gdn_v4 \
--cache-stage2 data_cache/chunked_context4096 \
--outputs-prefix checkpoints/qwen2_7b_gdn_v4_hybrid_layer_selection_{index}

Important: generate_layer_configs_ga_s2.py contains a STAGE2_TEMPLATE with a student_init_ckpt: field.
Before launching runs, make sure it points to your Stage-1 converted checkpoint, e.g.
student_init_ckpt: 'checkpoints/<your_stage1_exp>/stage1/converted-hf'
Also note: the template uses target_tokens: 600_000_000, but for ranking runs, ~200M tokens is typically enough. You can edit target_tokens in the template (or in the generated YAMLs) to speed up selection.
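If you prefer to patch the generated YAMLs rather than the template, a small, illustrative snippet like the following (paths follow the example above; adjust if yours differ) does the job:

```python
# Convenience sketch (not part of the repo): lower target_tokens in all
# generated layer-selection configs via a plain text substitution, which
# preserves the rest of each YAML file untouched.
import glob
import re

for path in glob.glob("config/qwen2_7b_gdn_v4_hybrid_layer_selection/layer_*/stage2.yaml"):
    with open(path) as f:
        text = f.read()
    text = re.sub(r"target_tokens:\s*\S+", "target_tokens: 200_000_000", text)
    with open(path, "w") as f:
        f.write(text)
```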
Run Stage 2 once per layer (these runs are independent and can be launched in parallel on a cluster):
for i in $(seq 0 27); do
deepspeed train.py --cfg config/qwen2_7b_gdn_v4_hybrid_layer_selection/layer_${i}/stage2.yaml
done

Each run logs to W&B, and the run display name is expected to match output_dir (see retrieval script below).
After all runs finish, fetch the logged loss curves from W&B:
- Edit these constants in layer_selection/retrieve_loss_log_from_wandb.py:
  - ENTITY
  - PROJECT
  - (optionally) OUTPUT_FILE
  - (optionally) the RUN_NAME pattern if your output_dir differs
- Then run:
python layer_selection/retrieve_loss_log_from_wandb.py

This produces a JSON file mapping each run name to sampled (train/global_step, train/loss) points.
Note: the retrieval script assumes:
- runs are finished
- the W&B run display_name equals the training output_dir

If your W&B naming differs, update the script accordingly.
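For reference, the retrieval is conceptually similar to the sketch below, which uses the public wandb API; the constants and output filename here are illustrative and may not match the actual script exactly.

```python
# Illustrative sketch of the retrieval logic (the real script in
# layer_selection/ may differ): collect (train/global_step, train/loss)
# points for every finished run in the project.
import json
import os

import wandb

ENTITY = os.environ["WANDB_ENTITY"]   # assumed to match your earlier setup
PROJECT = os.environ["WANDB_PROJECT"]
OUTPUT_FILE = "wandb_loss_log.json"   # hypothetical default name

api = wandb.Api()
loss_log = {}
for run in api.runs(f"{ENTITY}/{PROJECT}"):
    if run.state != "finished":
        continue
    # run.name is expected to equal the training output_dir (see note above).
    points = [
        (row["train/global_step"], row["train/loss"])
        for row in run.scan_history(keys=["train/global_step", "train/loss"])
    ]
    loss_log[run.name] = points

with open(OUTPUT_FILE, "w") as f:
    json.dump(loss_log, f)
```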
Point FILE_TO_LOAD in layer_selection/get_ranking_from_wandb_loss_log.py to the JSON produced above, then run:
python layer_selection/get_ranking_from_wandb_loss_log.py

This prints layer rankings (best → worst) at each logged step. In practice, you can take the ranking at a later step (or aggregate across steps) and then select the top-K layers as your final keep_full_attention_layers.
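As an illustration of this final selection step, the sketch below (not part of the repo) reads the JSON from the previous step, ranks layers by their last logged loss, and picks the top-K; it assumes run names end in the layer index, as produced by the --outputs-prefix example above.

```python
# Illustrative sketch: turn the W&B loss log into a top-K layer choice.
# Lower final distillation loss => restoring that layer helped more.
import json
import re

FILE_TO_LOAD = "wandb_loss_log.json"  # the JSON produced in the previous step
K = 4                                  # number of softmax layers to keep

with open(FILE_TO_LOAD) as f:
    loss_log = json.load(f)

final_loss = {}
for run_name, points in loss_log.items():
    layer_idx = int(re.search(r"(\d+)$", run_name).group(1))
    last_step, last_loss = max(points, key=lambda p: p[0])
    final_loss[layer_idx] = last_loss

ranking = sorted(final_loss, key=final_loss.get)
print("layers ranked best -> worst:", ranking)
print("keep_full_attention_layers:", sorted(ranking[:K]))
```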
Create a final Stage-2 config (or modify an existing one) with:
student_model:
  keep_full_attention_layers: [l1, l2, ..., lK]

Then run the standard distillation pipeline for your final hybrid model (Stage 2 only, since you can reuse the Stage-1 checkpoint).
For baseline layer-importance heuristics, we provide evaluate_layer_importance.py, which:
- evaluates a model on a chosen synthetic task,
- ablates each layer by zeroing its attention output via a forward hook (see the sketch after this list),
- reports the performance drop per layer (larger drop ⇒ more “important”).
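Conceptually, the per-layer ablation works like the sketch below. The module path (model.model.layers[i].self_attn) follows the usual Hugging Face layout and is an assumption here; the real script additionally handles task construction and scoring.

```python
# Minimal sketch of zeroing one layer's attention output via a forward hook
# (illustrative; evaluate_layer_importance.py does the full evaluation).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "fla-hub/Qwen2.5-3B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

def zero_attention_output(module, inputs, output):
    # Attention modules typically return a tuple whose first element is the
    # attention output; zero it so the layer contributes nothing.
    if isinstance(output, tuple):
        return (torch.zeros_like(output[0]),) + output[1:]
    return torch.zeros_like(output)

def evaluate_with_layer_ablated(layer_idx, eval_fn):
    """Run eval_fn() with layer `layer_idx`'s attention output zeroed."""
    attn = model.model.layers[layer_idx].self_attn
    handle = attn.register_forward_hook(zero_attention_output)
    try:
        return eval_fn()  # e.g. accuracy on a chosen synthetic task
    finally:
        handle.remove()
```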
Supported tasks (--task):
- synthetic_retrieval
- associative_recall
- associative_recall_mutihop
- variable_tracking
- common_words_extraction
- frequent_words_extraction
- phone_book
Example (Variable Tracking):
python layer_selection/evaluate_layer_importance.py \
--task variable_tracking \
--model_name fla-hub/Qwen2.5-3B-Instruct \
--output_dir layer_selection_baselines \
--num_samples 500 \
--batch_size 8 \
--num_chains 2 \
--num_hops 2 \
--noise_ratio 1.0

Artifacts saved under --output_dir:
- results/baseline_*.json (baseline accuracy)
- results/ablation_*.csv (layer-wise accuracy + drop)
- plots/ablation_*.png (visualization)
Evaluation is performed using the lm-evaluation-harness. To install:
git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .
pip install "lm_eval[ruler]" # ruler-specific install

To evaluate a checkpoint:
bash eval.sh /path/to/your/checkpoint
Important: Use the HF-converted checkpoint format, which is automatically saved to the converted-hf directory inside your output_dir.
We use the Triton-implemented linear attention kernels from fla-org/flash-linear-attention and refer to HazyResearch/lolcats when constructing our training pipeline. Evaluation is supported by lm-evaluation-harness. We thank the authors of these excellent open-source efforts.
TBD