
Commit 1d341eb

Author: Allen Wang

Merge branch 'main' into interfaces

2 parents: 4d11e4e + bb57589

File tree: 17 files changed, +464 −417 lines

README.md

Lines changed: 5 additions & 22 deletions
````diff
@@ -30,40 +30,23 @@ You can also find our notebook tutorials (coming soon)
 
 ## Installation
 
-### Basic
-
 torchforge requires PyTorch 2.9.0 with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan).
 
-You can install Forge with:
-```
-$ conda create -n forge python=3.10
-$ conda activate forge
-$ uv pip install .
-```
-
-(conda-less uv install is a wip)
-
-For your reference, we also include a basic install script that installs other system dependencies
-along with torchforge:
-(note that this basic install script
-uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.)
+Install torchforge with:
 
 ```bash
 conda create -n forge python=3.12
 conda activate forge
 ./scripts/install.sh
 ```
 
-Optional: By default, the packages installation uses conda. If user wants to install system packages on the target machine instead of conda, they can pass the `--use-sudo` to the installation script: `./script/install.sh --use-sudo`.
+The install script installs system dependencies along with torchforge. Note that this install script uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.
 
-After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
+Optional: By default, the packages installation uses conda. If you want to install system packages on the target machine instead of conda, you can pass the `--use-sudo` flag to the installation script: `./scripts/install.sh --use-sudo`.
 
+> **Note:** We are actively working on enabling pure `uv` installation. Currently, Conda is the recommended approach. `uv` support is not fully working at the moment but is being tracked in [issue #494](https://github.com/meta-pytorch/torchforge/issues/494).
 
-```
-uv run apps/grpo/main.py --config apps/grpo/qwen3_1_7b.yaml
-```
-
-or if not using uv:
+After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
 
 ```
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
````
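Before launching the GRPO example, a quick sanity check for the GPU requirement mentioned in the new README text (a minimal sketch; assumes PyTorch is installed per the steps above):

```python
# Sketch: verify the machine meets the "minimum 3 GPU devices" requirement.
import torch

ngpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
print(f"Visible CUDA devices: {ngpus}")
assert ngpus >= 3, "The GRPO example requires at least 3 GPU devices"
```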

apps/grpo/main.py

Lines changed: 2 additions & 2 deletions
````diff
@@ -23,7 +23,7 @@
 from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
@@ -318,7 +318,7 @@ async def main(cfg: DictConfig):
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        RLTrainer.options(**cfg.actors.trainer).as_actor(
+        TitanTrainer.options(**cfg.actors.trainer).as_actor(
             **cfg.trainer, loss=simple_grpo_loss
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
````

docs/source/api_trainer.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -7,17 +7,17 @@
 The Trainer manages model training in TorchForge, built on top of TorchTitan.
 It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows.
 
-## RLTrainer
+## TitanTrainer
 
 ```{eval-rst}
-.. autoclass:: RLTrainer
+.. autoclass:: TitanTrainer
    :members: train_step, push_weights, cleanup
    :exclude-members: __init__
 ```
 
 ## Configuration
 
-The RLTrainer uses TorchTitan's configuration system with the following components:
+The TitanTrainer uses TorchTitan's configuration system with the following components:
 
 ### Job Configuration
````
docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md

Lines changed: 5 additions & 5 deletions
````diff
@@ -96,7 +96,7 @@ graph LR
     S3["RewardActor"]
     S4["ReferenceModel"]
     S5["ReplayBuffer"]
-    S6["RLTrainer"]
+    S6["TitanTrainer"]
     end
 
     C1 --> S1
@@ -306,7 +306,7 @@ TorchForge handles behind the scenes:
 from forge.actors.generator import Generator as Policy
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.reference_model import ReferenceModel
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
 from forge.data.rewards import MathReward, ThinkingReward
 import asyncio
@@ -348,7 +348,7 @@ group_size = 1
         }
     ),
     # Trainer actor with GPU
-    RLTrainer.options(procs=1, with_gpus=True).as_actor(
+    TitanTrainer.options(procs=1, with_gpus=True).as_actor(
         # Trainer config would come from YAML in real usage
         model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"},
         optimizer={"name": "AdamW", "lr": 1e-5},
@@ -378,12 +378,12 @@ group_size = 1
 
 TorchForge has two types of distributed components:
 - **Services**: Multiple replicas with automatic load balancing (like Policy, RewardActor)
-- **Actors**: Single instances that handle their own internal distribution (like RLTrainer, ReplayBuffer)
+- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer)
 
 We cover this distinction in detail in Part 2, but for now this explains the scaling patterns:
 - Policy service: num_replicas=8 for high inference demand
 - RewardActor service: num_replicas=16 for parallel evaluation
-- RLTrainer actor: Single instance with internal distributed training
+- TitanTrainer actor: Single instance with internal distributed training
 
 
 ### Fault Tolerance
````
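To make the services-versus-actors split concrete, here is a minimal spawn sketch. The `.options(...).as_service()` / `.as_actor()` calls and config dicts mirror ones shown elsewhere in this diff; the replica counts and the `spawn_all` wrapper are illustrative, not part of the commit:

```python
import asyncio

from forge.actors.generator import Generator as Policy
from forge.actors.trainer import TitanTrainer


async def spawn_all():
    # Service: replicated and load-balanced across num_replicas copies.
    policy = await Policy.options(num_replicas=8, with_gpus=True).as_service(
        engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
    )
    # Actor: a single instance that manages its own internal distribution.
    trainer = await TitanTrainer.options(procs=1, with_gpus=True).as_actor(
        model={"name": "qwen3", "flavor": "1.7B"},
        optimizer={"name": "AdamW", "lr": 1e-5},
    )
    return policy, trainer


# asyncio.run(spawn_all())  # requires a working TorchForge deployment
```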

docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md

Lines changed: 5 additions & 5 deletions
````diff
@@ -470,7 +470,7 @@ async def simple_rl_step():
     if batch is not None:
         print("Training on batch...")
         inputs, targets = batch  # GRPO returns (inputs, targets) tuple
-        loss = await trainer.train_step.call(inputs, targets)  # RLTrainer is an actor
+        loss = await trainer.train_step.call(inputs, targets)  # TitanTrainer is an actor
         print(f"Training loss: {loss}")
         return loss
     else:
@@ -507,7 +507,7 @@ reward_actor = await RewardActor.options(
 )
 
 # Training needs fewer but more powerful replicas
-trainer = await RLTrainer.options(
+trainer = await TitanTrainer.options(
     procs=1, with_gpus=True  # Fewer but GPU-heavy
 ).as_actor(  # Trainer typically uses .as_actor() not .as_service()
     model={"name": "qwen3", "flavor": "1.7B"},
@@ -580,7 +580,7 @@ import torch
 from forge.actors.generator import Generator as Policy
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import RLTrainer
+from forge.actors.trainer import TitanTrainer
 from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages
 from forge.data.rewards import MathReward, ThinkingReward
 
@@ -603,7 +603,7 @@ print("Initializing all services...")
         engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1},
         sampling_config={"n": 1, "max_tokens": 512}
     ),
-    RLTrainer.options(procs=1, with_gpus=True).as_actor(
+    TitanTrainer.options(procs=1, with_gpus=True).as_actor(
         model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"},
         optimizer={"name": "AdamW", "lr": 1e-5},
         training={"local_batch_size": 2, "seq_len": 2048}
@@ -667,7 +667,7 @@ print("Shutting down services...")
 await asyncio.gather(
     DatasetActor.shutdown(dataloader),
     policy.shutdown(),
-    RLTrainer.shutdown(trainer),
+    TitanTrainer.shutdown(trainer),
     ReplayBuffer.shutdown(replay_buffer),
     ComputeAdvantages.shutdown(compute_advantages),
     ReferenceModel.shutdown(ref_model),
````

src/forge/actors/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
8+
79
__all__ = [
810
"Generator",
9-
"RLTrainer",
11+
"TitanTrainer",
12+
"RLTrainer", # Deprecated, use TitanTrainer
1013
"ReplayBuffer",
1114
"ReferenceModel",
1215
"SandboxedPythonCoder",
@@ -18,7 +21,17 @@ def __getattr__(name):
1821
from .generator import Generator
1922

2023
return Generator
24+
elif name == "TitanTrainer":
25+
from .trainer import TitanTrainer
26+
27+
return TitanTrainer
2128
elif name == "RLTrainer":
29+
warnings.warn(
30+
"RLTrainer is deprecated and will be removed in a future version. "
31+
"Please use TitanTrainer instead.",
32+
FutureWarning,
33+
stacklevel=2,
34+
)
2235
from .trainer import RLTrainer
2336

2437
return RLTrainer
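A small usage sketch of the shim above (assumes a working torchforge install): importing the old name still resolves via the module-level `__getattr__` (PEP 562), but surfaces a `FutureWarning`:

```python
# Sketch: the deprecated alias still resolves, but emits a FutureWarning.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from forge.actors import RLTrainer  # routed through forge.actors.__getattr__

print([str(w.message) for w in caught if issubclass(w.category, FutureWarning)])
```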
src/forge/actors/trainer/__init__.py (new file)

Lines changed: 23 additions & 0 deletions

````diff
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+from .titan import TitanTrainer
+
+__all__ = ["TitanTrainer", "RLTrainer"]
+
+
+def __getattr__(name):
+    if name == "RLTrainer":
+        warnings.warn(
+            "RLTrainer is deprecated and will be removed in a future version. "
+            "Please use TitanTrainer instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        return TitanTrainer
+    raise AttributeError(f"module {__name__} has no attribute {name}")
````
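Because this subpackage's `__getattr__` returns `TitanTrainer` itself, the deprecated name is a pure alias for the new class (a sketch, assuming a working install):

```python
from forge.actors.trainer import TitanTrainer
from forge.actors.trainer import RLTrainer  # emits FutureWarning

assert RLTrainer is TitanTrainer  # the shim hands back the same class object
```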

src/forge/actors/trainer.py renamed to src/forge/actors/trainer/titan.py

Lines changed: 2 additions & 2 deletions
````diff
@@ -53,8 +53,8 @@
 
 
 @dataclass
-class RLTrainer(ForgeActor):
-    """A reinforcement learning trainer actor for policy optimization training.
+class TitanTrainer(ForgeActor):
+    """A generic trainer actor implementation built on top of TorchTitan.
 
     Built on top of TorchTitan's training engine, this actor provides a complete training
     loop for reinforcement learning. It performs forward and backward passes with gradient
````

src/forge/controller/provisioner.py

Lines changed: 64 additions & 9 deletions
````diff
@@ -12,6 +12,8 @@
 import socket
 import uuid
 
+import torch
+
 from monarch._src.actor.actor_mesh import ActorMesh
 from monarch._src.actor.shape import Extent
 
@@ -41,8 +43,19 @@ class _RemoteInfoFetcher(Actor):
 
     @endpoint
     def get_info(self) -> tuple[str, str]:
+        """Returns hostname and port."""
         return socket.gethostname(), _get_port()
 
+    @endpoint
+    def get_gpu_count(self) -> int:
+        """Returns the number of GPUs available on this host."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If torch is not available or CUDA is not available, assume no GPUs
+            gpu_count = 0
+        return gpu_count
+
 
 class EnvSetter(Actor):
     """Actor to set environment variables on each proc in a mesh.
@@ -87,14 +100,26 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
     fetcher = fetcher.slice(**singleton_slice)
     # Fetcher should be a singleton at this point - call_one() will fail otherwise
-
     host, port = await fetcher.get_info.call_one()
 
     # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
     # await throwaway_procs.stop()
     return host, port
 
 
+async def get_host_gpus(host_mesh: HostMesh) -> int:
+    """Returns the number of GPUs available on the host mesh."""
+    throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
+    fetcher = throwaway_procs.spawn("_gpu_counter", _RemoteInfoFetcher)
+
+    # Reduce to a singleton
+    singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
+    fetcher = fetcher.slice(**singleton_slice)
+
+    gpu_count = await fetcher.get_gpu_count.call_one()
+    return gpu_count
+
+
 async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
     """Set environment variables on a proc mesh using EnvSetter actor.
 
@@ -112,17 +137,35 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
 class GpuManager:
     """Tracks and assigns GPU devices on a host.
 
-    This currently mimics the `gpu_manager` in system_controllers - we will
-    consolidate as part of the "proper HostMesh integration" work.
+    Args:
+        available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
+        max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
     """
 
-    def __init__(self, available_devices: set[int] | None = None):
+    def __init__(
+        self, available_devices: set[int] | None = None, max_device_count: int = 8
+    ):
         if available_devices is None:
-            available_devices = set(range(0, 8))
-        assert all(isinstance(x, int) for x in available_devices)
-        assert all(x >= 0 and x < 8 for x in available_devices)
+            available_devices = set(range(0, max_device_count))
+        else:
+            # Validate types first
+            assert all(
+                isinstance(x, int) for x in available_devices
+            ), f"All device IDs must be integers, got: {available_devices}"
+            # When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
+            # adjust max_device_count to accommodate the highest device ID
+            if available_devices:
+                max_device_count = max(max(available_devices) + 1, max_device_count)
+
+        assert all(
+            isinstance(x, int) for x in available_devices
+        ), f"All device IDs must be integers, got: {available_devices}"
+        assert all(
+            x >= 0 for x in available_devices
+        ), f"All device IDs must be non-negative, got: {available_devices}"
         self.available_gpus = available_devices
+        self.max_device_count = max_device_count
 
     def get_available_gpus(self) -> list[str]:
         """Returns a list of available GPU devices."""
@@ -171,8 +214,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
                     f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
                     f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
                 ) from e
+
+        # Get the actual GPU count for the local host
+        try:
+            local_gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If torch is not available or CUDA is not available, assume no GPUs
+            local_gpu_count = 0
+
         self._host_gpu_map = {
-            self._this_host_id: GpuManager(available_local_devices),
+            self._this_host_id: GpuManager(
+                available_local_devices, max_device_count=local_gpu_count
+            ),
         }
         self._proc_host_map = {}
         self._host_mesh_map = {}
@@ -277,7 +330,9 @@ async def get_proc_mesh(
                 num_hosts=num_hosts,
             )
             host_id = uuid.uuid1()
-            gpu_manager = GpuManager()
+            # Get the GPU count from the remote host
+            remote_gpu_count = await get_host_gpus(host_mesh)
+            gpu_manager = GpuManager(max_device_count=remote_gpu_count)
             self._host_gpu_map[host_id] = gpu_manager
             host_mesh._host_id = host_id
         else:
````
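A sketch of the new `GpuManager` semantics introduced here (illustrative values; assumes the class is importable as shown in the diff):

```python
# Sketch: GpuManager now sizes itself from max_device_count instead of a
# hard-coded 8, and grows to cover explicitly listed device IDs.
from forge.controller.provisioner import GpuManager

# Default: manage device IDs 0..max_device_count-1 (max_device_count defaults to 8).
gm = GpuManager(max_device_count=4)
assert sorted(gm.available_gpus) == [0, 1, 2, 3]

# Explicit devices, e.g. parsed from CUDA_VISIBLE_DEVICES="2,9":
# max_device_count expands to max(9 + 1, 8) = 10.
gm = GpuManager(available_devices={2, 9})
assert gm.max_device_count == 10
```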

src/forge/controller/system_controllers/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
