Conversation

@casteryh casteryh commented Nov 5, 2025

Summary

This PR adds a new online distillation app that demonstrates how a smaller student model (Qwen3-1.7B) can learn from a larger teacher model (Qwen3-32B) via KL divergence on generated completions.

This is an example of online distillation where:

  • Student model generates rollouts from prompts
  • Teacher model (frozen) provides target logprobs for the same completions
  • Training objective: Minimize KL divergence between student and teacher distributions
  • No rewards or advantages - pure distillation objective

Key Components

Architecture:

  • StudentGenerator (vLLM): Generates completions from prompts using student model
  • StudentTrainer: Trains student model to match teacher distributions
  • TeacherModel (frozen): Provides target logprobs for distillation
  • DatasetActor: Provides prompts (GSM8K)
  • ReplayBuffer: Batches episodes for training
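
A minimal, self-contained sketch of how these pieces hand data to each other in one step; the Episode fields and stand-in functions below are illustrative assumptions, not the app's actual API:

from dataclasses import dataclass
import torch

@dataclass
class Episode:                      # simplified: no reward/advantage fields
    tokens: torch.Tensor            # completion token ids sampled by the student
    teacher_logprobs: torch.Tensor  # frozen teacher's logprobs for those tokens
    mask: torch.Tensor              # 1.0 for completion tokens, 0.0 for padding

replay_buffer: list[Episode] = []   # stand-in for ReplayBuffer

def rollout(prompt_tokens: torch.Tensor) -> Episode:
    # StudentGenerator would call vLLM here; we fake a 4-token completion.
    completion = torch.randint(0, 100, (4,))
    # TeacherModel (frozen) scores the same completion; no gradients needed.
    teacher_lp = -torch.rand(4)     # fake (negative) logprobs
    return Episode(completion, teacher_lp, torch.ones(4))

replay_buffer.append(rollout(torch.tensor([1, 2, 3])))  # prompt from DatasetActor
# StudentTrainer then samples a batch of episodes and minimizes the KL
# divergence between its own logprobs and episode.teacher_logprobs.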

Loss Function:

def distillation_loss(student_logprobs, teacher_logprobs, padding_mask):
    # Forward KL: KL(teacher || student), averaged over non-padding tokens
    kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
    return (kl * padding_mask).sum() / padding_mask.sum().clamp(min=1)
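
A quick sanity check of the loss on dummy tensors; the shapes and values here are illustrative only, in the app they come from batched, padded episodes:

import torch

student_lp = torch.log(torch.tensor([[0.5, 0.2, 0.1, 0.1]]))  # [batch, seq]
teacher_lp = torch.log(torch.tensor([[0.6, 0.3, 0.1, 0.1]]))
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])                   # last position is padding

loss = distillation_loss(student_lp, teacher_lp, mask)
print(loss)  # scalar; 0 only when student and teacher logprobs agree on every unmasked token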

Implementation Details

Based on apps/grpo/main.py with key differences:

  • ✅ Removed RewardActor and ComputeAdvantages (not needed for distillation)
  • ✅ Replaced simple_grpo_loss with distillation_loss (pure KL divergence)
  • ✅ Simplified Episode dataclass (removed reward and advantage fields)
  • ✅ Renamed policy → student_generator, ref_model → teacher_model for clarity
  • ✅ Updated collate function to work without advantages

Files Added

  • apps/distillation/main.py: Main training loop
  • apps/distillation/qwen3_distillation.yaml: Config for Qwen3-1.7B → Qwen3-32B distillation

Usage

python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml

Test Plan

  • ✅ Verified Python syntax with py_compile
  • ✅ Config follows same pattern as GRPO configs
  • ✅ All pre-commit hooks passed

cc @wukaixingxp

This adds a new online distillation app that demonstrates how to use a smaller student model to learn from a larger teacher model via KL divergence on generated completions.

Key components:
- Student model (Qwen3-1.7B): Generates rollouts and gets trained
- Teacher model (Qwen3-32B): Frozen, provides target logprobs for distillation
- Loss: Pure KL divergence between student and teacher distributions
- No rewards or advantages - direct distillation objective

Implementation based on apps/grpo/main.py with key differences:
- Removed RewardActor and ComputeAdvantages
- Replaced GRPO loss with distillation_loss (KL divergence)
- Simplified Episode dataclass (no reward/advantage fields)
- Renamed policy → student_generator, ref_model → teacher_model for clarity

Usage:
python -m apps.distillation.main --config apps/distillation/qwen3_distillation.yaml

Test Plan:
- Verified Python syntax with py_compile
- Config follows same pattern as GRPO configs
@casteryh casteryh requested a review from wukaixingxp November 5, 2025 01:06
@meta-cla meta-cla bot added the CLA Signed label Nov 5, 2025
Changed from forward KL to reverse KL to match the standard online distillation objective:
KL(student || teacher) = E_{x~student}[log student - log teacher]

This is the natural choice for online distillation because:
1. We sample from the student policy during rollouts
2. Reverse KL is mode-seeking (student focuses on teacher's high-probability modes)
3. Simpler formula: just student_logprobs - teacher_logprobs
4. More stable gradients

The formula is now:
kl = student_logprobs - teacher_logprobs

instead of the previous forward KL:
kl = teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
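
For comparison, the two per-token estimators side by side; this is a sketch assuming both logprob tensors are already gathered at the sampled tokens, not the exact code in the commit:

def forward_kl(student_logprobs, teacher_logprobs):
    # KL(teacher || student): each token weighted by the teacher's probability.
    return teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)

def reverse_kl(student_logprobs, teacher_logprobs):
    # KL(student || teacher): the tokens were sampled from the student, so the
    # expectation over the student policy is taken by sampling and the
    # per-token estimate reduces to the logprob difference.
    return student_logprobs - teacher_logprobs
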
…eward

Changed the loss formulation to:
  reward = -reverse_kl = -(sampled_logprobs - teacher_logprobs)
  loss = -E[importance_weight * reward]

where:
- reverse_kl is DETACHED (no backprop through it)
- importance_weight = exp(logprobs - logprobs.detach())

This treats distillation as a reward-based objective where the "reward" is
how well the student matches the teacher (negative KL). The gradient flows
through the importance sampling term only, not through the KL itself.

This is similar to GRPO's policy gradient term but without the KL penalty:
  per_token_policy_loss = exp(logprobs - logprobs.detach()) * reward
  loss = -per_token_policy_loss

Key difference from previous implementation:
- Before: Direct KL minimization with backprop through both student and teacher logprobs
- Now: REINFORCE-style gradient with detached KL as reward signal
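
A sketch of that formulation in code; the tensor names follow the commit message, while the masking and final reduction are assumptions:

import torch

def distillation_loss(logprobs, teacher_logprobs, padding_mask):
    # "Reward" = negative reverse KL, detached so no gradient flows through it.
    reward = -(logprobs - teacher_logprobs).detach()
    # Importance weight equals 1 in value but carries the policy gradient,
    # like GRPO's per-token policy term without the KL penalty.
    importance_weight = torch.exp(logprobs - logprobs.detach())
    per_token_loss = -importance_weight * reward
    return (per_token_loss * padding_mask).sum() / padding_mask.sum().clamp(min=1)
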
@casteryh casteryh changed the title Add online distillation app [DO NOT REVIEW] online distillation example Nov 5, 2025
@casteryh casteryh changed the title [DO NOT REVIEW] online distillation example [DO NOT REVIEW] on-policy distillation example Nov 5, 2025
@casteryh casteryh changed the title [DO NOT REVIEW] on-policy distillation example [DO NOT REVIEW][NOT FOR LAND] on-policy distillation example Nov 5, 2025