107 changes: 107 additions & 0 deletions apps/on_policy_distillation/README.md
# On-Policy Distillation for Math Reasoning

This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training.

## Overview

On-policy distillation trains a student model by:
1. Sampling trajectories from the student model itself
2. Grading every token of those trajectories with a teacher model, yielding a dense per-token reward (the negative reverse KL)
3. Training the student to minimize the per-token reverse KL with the teacher (see the sketch below)
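
In `main.py`, these steps map onto the `Generator` service (student sampling), the `ReferenceModel` teacher (per-token grading), and the `RLTrainer` loss. The toy sketch below only illustrates the per-token signal itself, with random logits standing in for real models; it is not this app's API.

```python
import torch

# Toy stand-ins: random logits play the role of the student and teacher LLMs.
batch, seq_len, vocab = 2, 8, 32
student_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, seq_len, vocab)

# 1. Sample a trajectory from the student itself.
tokens = torch.distributions.Categorical(logits=student_logits).sample()  # (batch, seq_len)

# 2. Grade every sampled token with the teacher: a dense per-token signal.
def token_logprobs(logits: torch.Tensor) -> torch.Tensor:
    return torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

student_lp, teacher_lp = token_logprobs(student_logits), token_logprobs(teacher_logits)

# 3. Train the student to drive the per-token reverse KL toward zero.
reverse_kl = -(student_lp - teacher_lp)
reverse_kl.mean().backward()
```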

This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance.

## Experimental Setup

### Models
- **Student**: Qwen3-0.6B-Base (or Qwen3-8B for larger experiments)
- **Teacher**: Qwen3-8B (or Qwen3-32B)
- **Evaluation**: AIME'24 benchmark

### Training Pipeline

#### Phase 1: Supervised Fine-Tuning (SFT)
First, establish a strong baseline through off-policy distillation:

```bash
python -m apps.sft.main --config apps/sft/qwen3_0_6.yaml
```

- **Dataset**: OpenThoughts3-1.2M (400k prompts)
- **Expected Performance**: ~60% on AIME'24
- **Purpose**: Teaches the model basic math reasoning patterns

#### Phase 2: On-Policy Distillation
Refine the model using on-policy learning with dense supervision:

```bash
python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/qwen_opd.yaml
```

- **Starting Point**: SFT checkpoint from Phase 1
- **Dataset**: Math prompts from OpenThoughts3 or DeepMath (prompts only; the reference solutions are not used, as sketched below)
- **Training**: ~150 steps (77k prompts with 4 samples each)
- **Expected Performance**: ~70% on AIME'24
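
Only the prompt side of each record is used; completions are sampled from the student at training time. A minimal sketch of pulling prompts from an OpenThoughts3-style record follows, matching the `sample["conversations"][0]["value"]` access in `main.py` (adjust the field names if your copy of the dataset differs):

```python
from datasets import load_dataset

dataset = load_dataset("open-thoughts/OpenThoughts3-1.2M", split="train", streaming=True)

def extract_prompt(sample: dict) -> str:
    # Keep only the first (user) turn; the reference solution is deliberately ignored.
    return sample["conversations"][0]["value"]

prompts = (extract_prompt(s) for s in dataset)
print(next(prompts)[:200])
```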

### Key Implementation Details

1. **Loss Function**: Per-token reverse KL divergence
```python
reverse_kl = -(student_logprobs - teacher_logprobs)
```

2. **Sampling**: Generate multiple trajectories per prompt (n=16 in config)

3. **No Discount Factor**: Each token is optimized only against its own immediate reverse KL; no reward is propagated from later tokens (discount factor = 0)

4. **Efficient Batching**: Can use smaller batch sizes than RL due to dense rewards
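
Taken together, these details correspond to `importance_sampling_loss` in `apps/on_policy_distillation/main.py`. A condensed sketch is shown below (the helper name `opd_loss` is illustrative; the real loss also recomputes student logprobs from the trainer's current logits via `compute_logprobs`):

```python
import torch

def opd_loss(student_logprobs: torch.Tensor,
             teacher_logprobs: torch.Tensor,
             padding_mask: torch.Tensor) -> torch.Tensor:
    # Per-token reverse KL, weighted by the teacher/student probability ratio,
    # masked over padding and averaged per sequence.
    mask = padding_mask.float()
    reverse_kl = -(student_logprobs - teacher_logprobs)
    prob_ratio = torch.exp(teacher_logprobs - student_logprobs)
    per_token = prob_ratio * reverse_kl * mask
    per_seq = per_token.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return per_seq.mean()
```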

## Evaluation

Evaluate on AIME'24 benchmark after each phase:

```bash
python -m apps.eval.aime --checkpoint <path_to_checkpoint>
```

## Expected Results

| Method | AIME'24 Score | Training Cost |
|--------|---------------|---------------|
| SFT (400k prompts) | ~60% | Baseline |
| SFT (2M prompts, extrapolated) | ~70% | 5x baseline |
| OPD (150 steps) | ~70% | 0.1-0.3x baseline |

## Key Advantages

- **Compute Efficiency**: 10-30x reduction vs traditional RL
- **Dense Supervision**: Learns from every token, not just final rewards
- **Data Efficiency**: Can reuse prompts multiple times effectively
- **Stability**: More stable training than sparse RL rewards

## Notes for Reproduction

1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD
2. **Use prompts only**: During OPD, sample completions from the student; do not train on the dataset's solutions
3. **Teacher quality matters**: Better teachers provide better supervision
4. **Monitor reverse KL**: Should decrease to near-zero as training progresses

## References

- [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/)
- [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook)
- [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M)

---

**Required code setup**: The OPD implementation must:
1. Load from an SFT checkpoint (not raw base model)
2. Extract only prompts from the dataset (not use the solutions)
3. Add proper checkpoint loading in the trainer config:

```yaml
trainer:
checkpoint:
initial_load_path: ./checkpoint_student/sft_final # Load SFT checkpoint
# ... rest of config
```
7 changes: 7 additions & 0 deletions apps/on_policy_distillation/config.py
from dataclasses import dataclass


@dataclass
class DatasetConfig:
source: str
split: str = "train"
222 changes: 222 additions & 0 deletions apps/on_policy_distillation/main.py
import asyncio
import itertools
import time
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors.generator import Generator
from forge.actors.reference_model import ReferenceModel
from forge.actors.trainer import RLTrainer
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.util.config import parse
from forge.util.ops import compute_logprobs
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


@dataclass
class Trajectory:
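    # One student-sampled completion plus the teacher's per-token logprobs for it.
    # request_tensor / response_tensor pad or truncate to fixed lengths so that
    # trajectories can be stacked into batches.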
pad_id: int
request_len: int
response_len: int
# Processed data
completion: Completion | None = None
teacher_logprobs: torch.Tensor | None = None

@property
def request_tensor(self) -> torch.Tensor:
tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
elif tensor.shape[0] > self.request_len: # truncate
tensor = tensor[-self.request_len :]
return tensor

@property
def response_tensor(self) -> torch.Tensor:
tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
elif tensor.shape[0] > self.response_len: # truncate
tensor = tensor[: self.response_len]
return tensor


def collate(
batches: list[list[Trajectory]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
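    """Stack each group of trajectories into one input/target dict pair.

    Inputs hold the concatenated (request, response) token ids; targets hold the
    response ids, the teacher logprobs, and a padding mask used by the loss.
    """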
inputs = []
targets = []
for batch in batches:
request = [t.request_tensor for t in batch]
request = torch.stack(request)

response = [t.response_tensor for t in batch]
response = torch.stack(response)

teacher_logprobs = [t.teacher_logprobs for t in batch]
teacher_logprobs = torch.stack(teacher_logprobs)

# student_logprobs = [t.completion.logprobs for t in batch]
# student_logprobs = torch.stack(student_logprobs)

pad_id = batch[0].pad_id
padding_mask = response != pad_id

input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"teacher_logprobs": teacher_logprobs,
# "student_logprobs": student_logprobs,
"padding_mask": padding_mask,
}
inputs.append(input)
targets.append(target)
return inputs, targets


def importance_sampling_loss(
logits: torch.Tensor,
response: torch.Tensor,
teacher_logprobs: torch.Tensor,
# student_logprobs: torch.Tensor,
padding_mask: torch.Tensor,
**kwargs,
) -> torch.Tensor:
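    # Recompute the student's logprobs for the sampled response from the current
    # logits, form the per-token reverse KL against the teacher, weight it by the
    # teacher/student probability ratio, and average over non-padding tokens.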
student_logprobs = compute_logprobs(logits, response)
reverse_kl = -(student_logprobs - teacher_logprobs)
    # Reviewer suggestion from the PR: detach the student logprobs in the
    # reverse-KL term, i.e.
    #   reverse_kl = -(student_logprobs.detach() - teacher_logprobs)
prob_ratio = torch.exp(teacher_logprobs - student_logprobs)
per_token_loss = prob_ratio * reverse_kl

# Apply mask and compute mean over valid tokens
masked_loss = per_token_loss * padding_mask
num_valid_tokens = padding_mask.sum(dim=1, keepdim=True).clamp(min=1.0)
loss_per_sequence = masked_loss.sum(dim=1, keepdim=True) / num_valid_tokens
loss = loss_per_sequence.mean()

return loss


async def main(cfg: DictConfig):
train_batch_size = cfg.train_batch_size
max_steps = cfg.trainer.training.steps
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens

provisioner = await init_provisioner()
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(
{
"wandb": {"project": "opd-v0", "logging_mode": "global_reduce"},
"console": {"logging_mode": "global_reduce"},
}
)
student_trainer, student_generator, teacher = await asyncio.gather(
RLTrainer.options(**cfg.services.trainer).as_actor(
**cfg.trainer, loss=importance_sampling_loss
),
Generator.options(**cfg.services.student_generator).as_service(
**cfg.student_generator
),
ReferenceModel.options(**cfg.services.teacher).as_service(**cfg.teacher),
)

# Setup torchstore for weight management
trainer_num_procs = cfg.services.trainer["procs"]
trainer_host_mesh_name = cfg.services.trainer["mesh_name"]
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
)

# Load dataset
tokenizer = get_tokenizer(cfg.student_model)
pad_id = tokenizer.pad_token_id
dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.get("split", "train"))
# dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"])
dataset_iter = iter(dataset)

print("All services initialized successfully!")

    step = 0
    while step < max_steps:
        start = time.perf_counter()

trajectories = []
while len(trajectories) < train_batch_size:
try:
sample = next(dataset_iter)
conversation = sample["conversations"]
prompt = conversation[0]["value"]

completions = await student_generator.generate.fanout(prompt)

for completion in itertools.chain(*completions):
# Create trajectory with raw completion
trajectory = Trajectory(
pad_id=pad_id,
request_len=max_req_tokens,
response_len=max_res_tokens,
completion=completion,
)

# Build padded input for teacher using trajectory properties
input_ids = torch.cat(
[
trajectory.request_tensor.unsqueeze(0),
trajectory.response_tensor.unsqueeze(0),
],
dim=1,
)

teacher_logprobs = await teacher.forward.route(
input_ids, max_req_tokens, return_logprobs=True
)

trajectory.teacher_logprobs = teacher_logprobs
trajectories.append(trajectory)
except StopIteration:
print("Dataset exhausted, resetting iterator")
dataset_iter = iter(dataset)

# Train on collected trajectories
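        # Regroup the flat trajectory list into `train_batch_size` interleaved groups.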
trajectories = [
trajectories[i::train_batch_size] for i in range(train_batch_size)
]
inputs, targets = collate(trajectories)
await student_trainer.train_step.call(inputs, targets)

step += 1

# Push weights to student generator
await student_trainer.push_weights.call(step)
await student_generator.update_weights.fanout(step)

end = time.perf_counter()
print(f"Step {step} took {end - start} seconds")

await mlogger.flush.call_one(step)

print(f"Training completed after {step} steps")
await shutdown()


if __name__ == "__main__":

@parse
def _main(cfg):
asyncio.run(main(cfg))

_main()