
Commit 2a46586

committed
Stub
1 parent 15da0db commit 2a46586

9 files changed, +273 -55 lines changed

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# On-Policy Distillation for Math Reasoning

This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training.

## Overview

On-policy distillation trains a student model by:
1. Sampling trajectories from the student model itself
2. Using a teacher model to grade each token with dense rewards (per-token KL divergence)
3. Training the student to minimize the reverse KL with the teacher

This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance; a minimal sketch of the loop is shown below.
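For orientation, here is a minimal, framework-agnostic sketch of that loop. It is illustrative only: `sample_fn`, `student_logprob_fn`, and `teacher_logprob_fn` are hypothetical callables standing in for the actors and endpoints this app actually uses.

```python
# Illustrative OPD training step; the callables are hypothetical placeholders,
# not the APIs used in this repo.
import torch


def opd_step(sample_fn, student_logprob_fn, teacher_logprob_fn, prompts,
             optimizer, samples_per_prompt=4):
    losses = []
    for prompt in prompts:
        # 1. Sample trajectories from the student itself (on-policy).
        for completion in sample_fn(prompt, samples_per_prompt):
            # 2. Grade every sampled token with the teacher (dense signal).
            student_lp = student_logprob_fn(prompt, completion)  # [seq_len]
            teacher_lp = teacher_logprob_fn(prompt, completion)  # [seq_len]

            # 3. Per-token reverse KL estimate, detached so it acts as a
            #    per-token advantage rather than a differentiable target.
            reverse_kl = (student_lp - teacher_lp).detach()
            losses.append((student_lp * reverse_kl).mean())

    loss = torch.stack(losses).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Detaching the reverse-KL term and multiplying it by the student's log-probabilities gives a surrogate whose gradient matches the reverse-KL objective on the student's own samples.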
## Experimental Setup

### Models
- **Student**: Qwen3-0.6B-Base (or Qwen3-8B for larger experiments)
- **Teacher**: Qwen3-8B (or Qwen3-32B)
- **Evaluation**: AIME'24 benchmark

### Training Pipeline

#### Phase 1: Supervised Fine-Tuning (SFT)
First, establish a strong baseline through off-policy distillation:

```bash
python -m apps.sft.main --config apps/sft/qwen3_0_6.yaml
```

- **Dataset**: OpenThoughts3-1.2M (400k prompts); see the dataset config sketch below
- **Expected Performance**: ~60% on AIME'24
- **Purpose**: Teaches the model basic math reasoning patterns
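The SFT entry point reads its dataset settings from the job config (see the changes to `apps/sft/main.py` further down in this commit). A dataset block along these lines selects the OpenThoughts message transform instead of the default Alpaca one; the exact layout of `apps/sft/qwen3_0_6.yaml` may differ, so treat this as a sketch:

```yaml
# Sketch only: key names match the fields read in setup_data() of
# apps/sft/main.py; the surrounding config layout is assumed.
dataset:
  path: "open-thoughts/OpenThoughts3-1.2M"
  split: "train"
  message_transform: "openthoughts"      # falls back to "alpaca" if omitted
  masking_strategy: "train_on_assistant"
```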
#### Phase 2: On-Policy Distillation
Refine the model using on-policy learning with dense supervision:

```bash
python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/qwen_opd.yaml
```

- **Starting Point**: SFT checkpoint from Phase 1
- **Dataset**: Math prompts (from OpenThoughts3 or DeepMath; prompts only, not the solutions)
- **Training**: ~150 steps (77k prompts with 4 samples each)
- **Expected Performance**: ~70% on AIME'24
### Key Implementation Details

1. **Loss Function**: Per-token reverse KL divergence, estimated on tokens sampled from the student (a fuller loss sketch follows this list):

   ```python
   # Reverse KL per sampled token; its negative is the dense per-token reward.
   reverse_kl = student_logprobs - teacher_logprobs
   ```

2. **Sampling**: Generate multiple trajectories per prompt (n=16 in config)

3. **No Discount Factor**: Optimize only the immediate next token (discount=0)

4. **Efficient Batching**: Can use smaller batch sizes than RL due to the dense per-token rewards
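As a concrete illustration, a masked per-token reverse-KL loss with the same signature as `importance_sampling_loss` in `main.py` could look like the sketch below. The body is illustrative, not a copy of the repo's implementation; it assumes `logits` are the student's logits over the response tokens and that `teacher_logprobs` are already aligned with `response`.

```python
import torch
import torch.nn.functional as F


def reverse_kl_loss(
    logits: torch.Tensor,            # [batch, seq, vocab] student logits for the response
    response: torch.Tensor,          # [batch, seq] sampled response token ids
    teacher_logprobs: torch.Tensor,  # [batch, seq] teacher log-probs of those tokens
    padding_mask: torch.Tensor,      # [batch, seq] True for real (non-pad) tokens
    **kwargs,
) -> torch.Tensor:
    # Log-probability the student assigns to each sampled token.
    student_logprobs = torch.gather(
        F.log_softmax(logits.float(), dim=-1), dim=-1, index=response.unsqueeze(-1)
    ).squeeze(-1)

    # Per-token reverse KL estimate, detached so it acts as a per-token
    # advantage (reward = -reverse_kl) rather than a differentiable target.
    reverse_kl = (student_logprobs - teacher_logprobs).detach()

    # REINFORCE-style surrogate: its gradient matches the reverse-KL objective.
    per_token_loss = student_logprobs * reverse_kl * padding_mask

    # Average over real (non-padding) tokens only.
    return per_token_loss.sum() / padding_mask.sum().clamp(min=1)
```

The repo's `importance_sampling_loss` presumably also corrects for any mismatch between the sampling engine and the trainer's recomputed logits; that correction is omitted here.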
## Evaluation

Evaluate on AIME'24 benchmark after each phase:

```bash
python -m apps.eval.aime --checkpoint <path_to_checkpoint>
```
## Expected Results

| Method | AIME'24 Score | Training Cost |
|--------|---------------|---------------|
| SFT (400k prompts) | ~60% | Baseline |
| SFT (2M prompts, extrapolated) | ~70% | 5x baseline |
| OPD (150 steps) | ~70% | 0.1-0.3x baseline |
## Key Advantages

- **Compute Efficiency**: 10-30x reduction vs traditional RL
- **Dense Supervision**: Learns from every token, not just final rewards
- **Data Efficiency**: Can reuse prompts multiple times effectively
- **Stability**: More stable training than sparse RL rewards
## Notes for Reproduction

1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD
2. **Use prompts only**: During OPD, sample completions from the student; don't use the dataset solutions (see the extraction sketch below)
3. **Teacher quality matters**: Better teachers provide better supervision
4. **Monitor reverse KL**: It should decrease toward zero as training progresses
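For point 2, the OPD loop in `main.py` takes only the first (human) turn of each OpenThoughts conversation as the prompt. A standalone sketch of that extraction, with a made-up example row standing in for a dataset sample:

```python
# Illustrative prompt extraction for OpenThoughts-style rows, mirroring the
# sample["conversations"][0]["value"] access in main.py; the example row below
# is invented for demonstration.
def extract_prompt(row: dict) -> str:
    conversation = row["conversations"]
    # The first turn is the human prompt; later turns (the solutions) are
    # intentionally ignored so completions come from the student itself.
    return conversation[0]["value"]


row = {
    "conversations": [
        {"from": "human", "value": "Compute 1 + 2 + ... + 100."},
        {"from": "gpt", "value": "5050"},
    ]
}
print(extract_prompt(row))  # Compute 1 + 2 + ... + 100.
```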
## References

- [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/)
- [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook)
- [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M)

---
**Important Code Modification Needed**: Your current OPD implementation should:

1. Load from an SFT checkpoint (not raw base model)
2. Extract only prompts from the dataset (not use the solutions)
3. Add proper checkpoint loading in the trainer config:

```yaml
trainer:
  checkpoint:
    initial_load_path: ./checkpoint_student/sft_final  # Load SFT checkpoint
    # ... rest of config
```
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass
class DatasetConfig:
    source: str
    split: str = "train"

apps/on-policy-distillation/main.py renamed to apps/on_policy_distillation/main.py

Lines changed: 17 additions & 11 deletions
@@ -1,4 +1,6 @@
 import asyncio
+import itertools
+import time
 from dataclasses import dataclass
 from typing import Any
 
@@ -63,13 +65,17 @@ def collate(
         teacher_logprobs = [t.teacher_logprobs for t in batch]
         teacher_logprobs = torch.stack(teacher_logprobs)
 
+        # student_logprobs = [t.completion.logprobs for t in batch]
+        # student_logprobs = torch.stack(student_logprobs)
+
         pad_id = batch[0].pad_id
         padding_mask = response != pad_id
 
         input = {"tokens": torch.cat([request, response], dim=1)}
         target = {
             "response": response,
             "teacher_logprobs": teacher_logprobs,
+            # "student_logprobs": student_logprobs,
             "padding_mask": padding_mask,
         }
         inputs.append(input)
@@ -81,6 +87,7 @@ def importance_sampling_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
     teacher_logprobs: torch.Tensor,
+    # student_logprobs: torch.Tensor,
     padding_mask: torch.Tensor,
     **kwargs,
 ) -> torch.Tensor:
@@ -135,32 +142,28 @@ async def main(cfg: DictConfig):
     tokenizer = get_tokenizer(cfg.student_model)
     pad_id = tokenizer.pad_token_id
     dataset = load_dataset(cfg.dataset.path, split=cfg.dataset.get("split", "train"))
-    dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"])
+    # dataset = dataset.filter(lambda x: x["domain"] == cfg.dataset["domain"])
     dataset_iter = iter(dataset)
 
     print("All services initialized successfully!")
 
     step = 0
     for epoch in range(max_steps):
+        # start time
+        start = time.perf_counter()
         if step >= max_steps:
             break
 
-        # Collect rollout
         trajectories = []
         while len(trajectories) < train_batch_size:
             try:
                 sample = next(dataset_iter)
-                # Extract the human prompt from OpenThoughts format
-                conversations = sample.get("conversations", [])
-                if conversations and len(conversations) > 0:
-                    prompt = conversations[0].get("value", "")
-                else:
-                    prompt = sample.get("prompt", sample.get("text", ""))
+                conversation = sample["conversations"]
+                prompt = conversation[0]["value"]
 
-                print(f"Starting request with prompt: {prompt}")
-                completions = await student_generator.generate.route(prompt)
+                completions = await student_generator.generate.fanout(prompt)
 
-                for completion in completions:
+                for completion in itertools.chain(*completions):
                     # Create trajectory with raw completion
                     trajectory = Trajectory(
                         pad_id=pad_id,
@@ -201,6 +204,9 @@ async def main(cfg: DictConfig):
         await student_trainer.push_weights.call(step)
         await student_generator.update_weights.fanout(step)
 
+        end = time.perf_counter()
+        print(f"Step {step} took {end - start} seconds")
+
         await mlogger.flush.call_one(step)
 
     print(f"Training completed after {step} steps")

apps/on-policy-distillation/qwen_0_6b_to_8b.yaml renamed to apps/on_policy_distillation/qwen_0_6b_to_8b.yaml

Lines changed: 21 additions & 22 deletions
@@ -2,17 +2,16 @@
 # >>> python -m apps.on-policy-distillation.main --config apps/on-policy-distillation/qwen_0_6b_to_8b.yaml
 
 # Global configuration
-train_batch_size: 4 # Number of trajectories per training step
-max_req_tokens: 512
-max_res_tokens: 65536
-student_model: "Qwen/Qwen3-1.7B"
+train_batch_size: 64 # Number of trajectories per training step
+max_req_tokens: 2048
+max_res_tokens: 4096
+student_model: "Qwen/Qwen3-0.6B"
 teacher_model: "Qwen/Qwen3-8B"
 
 # Dataset configuration
 dataset:
   path: "open-thoughts/OpenThoughts3-1.2M"
   split: "train"
-  domain: "math"
 
 # Student Generator configuration (inference model)
 student_generator:
@@ -22,7 +21,7 @@ student_generator:
   pipeline_parallel_size: 1
   enforce_eager: false
   sampling_params:
-    n: 2 # Single response per prompt
+    n: 16
     max_tokens: ${max_res_tokens}
     temperature: 1.0
     top_p: 0.95
@@ -31,7 +30,7 @@ student_generator:
 trainer:
   model:
     name: qwen3
-    flavor: 1.7B
+    flavor: 0.6B
     hf_assets_path: hf://${student_model}
   optimizer:
     name: AdamW
@@ -41,32 +40,32 @@ trainer:
     warmup_steps: 10
   training:
     local_batch_size: ${train_batch_size} # Per-device batch size
-    seq_len: 66048
+    seq_len: 8192
     max_norm: 1.0
     steps: 10000
     dtype: bfloat16
-    gc_freq: 1
+    gc_freq: 5
   compile:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: 2
+    data_parallel_shard_degree: 1
     tensor_parallel_degree: 1
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
     disable_loss_parallel: true
   checkpoint:
     enable: true
-    # folder: ./checkpoint_student
+    folder: ./checkpoint_student
     initial_load_path: hf://${student_model}
     initial_load_in_hf: true
     last_save_in_hf: true
-    interval: 500
+    interval: 250
     async_mode: "disabled"
   activation_checkpoint:
-    mode: selective
-    selective_ac_option: op
+    mode: none
+    # selective_ac_option: op
 
 # Teacher model configuration
 teacher:
@@ -77,13 +76,13 @@ teacher:
   training:
     seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
-    gc_freq: 1
+    gc_freq: 10
   compile:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: 2
-    tensor_parallel_degree: 1 # Use 2 GPUs for teacher
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 1
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
@@ -95,17 +94,17 @@ teacher:
 # Resource allocations (3 GPUs total)
 services:
   student_generator:
-    procs: 1 # Student inference: 1 GPU
-    num_replicas: 1
+    procs: 1
+    num_replicas: 4
     mesh_name: student_generator
     with_gpus: true
   teacher:
-    procs: 2 # Teacher: 2 GPUs with TP
-    num_replicas: 1
+    procs: 1
+    num_replicas: 2
     mesh_name: teacher
     with_gpus: true
   trainer:
-    procs: 2 # Student training: shares GPU with student_generator
+    procs: 1
     num_replicas: 1
     mesh_name: trainer
     with_gpus: true

apps/sft/main.py

Lines changed: 28 additions & 5 deletions
@@ -25,7 +25,11 @@
 from forge.controller import ForgeActor
 from forge.data.collate import collate_packed
 from forge.data.datasets.packed import PackedDataset, TextPacker
-from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
+from forge.data.datasets.sft_dataset import (
+    AlpacaToMessages,
+    OpenThoughtsToMessages,
+    sft_iterable_dataset,
+)
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
@@ -165,13 +169,32 @@ def setup_data(self):
             ),
         )
 
+        # Get dataset configuration from job_config
+        dataset_config = self.job_config["dataset"]
+        dataset_path = dataset_config["path"]
+        dataset_split = dataset_config["split"]
+        message_transform_type = dataset_config.get("message_transform", "alpaca")
+        masking_strategy = dataset_config.get("masking_strategy", "train_on_assistant")
+
+        # Select the appropriate message transform
+        if message_transform_type == "openthoughts":
+            message_transform = OpenThoughtsToMessages(
+                masking_strategy=masking_strategy
+            )
+        elif message_transform_type == "alpaca":
+            message_transform = AlpacaToMessages(masking_strategy=masking_strategy)
+        else:
+            raise ValueError(
+                f"Unknown message_transform type: {message_transform_type}"
+            )
+
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
-            message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            message_transform=message_transform,
+            path=dataset_path,
+            split=dataset_split,
         )
-        packer = TextPacker(padding_idx=0)
+        packer = TextPacker(padding_idx=151643)
         dataset = PackedDataset(
             dataset=dataset,
             packer=packer,

src/forge/actors/reference_model.py

Lines changed: 4 additions & 4 deletions
@@ -15,10 +15,6 @@
 import torch
 import torch.nn.functional as F
 
-from forge.controller import ForgeActor
-from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import Tracer
-
 # from forge.util.ops import compute_logprobs
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
@@ -34,6 +30,10 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
+from forge.controller import ForgeActor
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
0 commit comments
