3 changes: 3 additions & 0 deletions apps/grpo/qwen3_32b.yaml
@@ -12,6 +12,9 @@ off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm
cpu: # CPUs per node - if empty, will be inferred from Slurm
memory_mb: # Memory in MB per node - if empty, will be inferred from Slurm
gpus_per_node: # Number of GPUs per node - if empty, will be inferred from Slurm

# Main loop configuration
rollout_threads: 32 # setting this to 4x the number of policy replicas seems to work well
109 changes: 84 additions & 25 deletions apps/sft/main.py
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -23,11 +22,13 @@

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse

from monarch.actor import current_rank, current_size, endpoint
@@ -41,8 +42,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -78,10 +77,13 @@ def __init__(self, config: DictConfig):

self.current_step = 0
self.num_training_steps = job_config.training.steps
self.metric_logger = None # TODO: fix this
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

self._init_dist()

super().__init__(job_config)

def _init_dist(self):
Contributor: if we remove the _init_dist altogether would this still work? I added this line in get_proc_mesh later, so this should not be needed anymore. Could you please try it out?

Contributor (Author): I need to test this. Right now, I pass the local rank and NCCL variables within the env there. Will keep you posted.

Contributor: Maybe we should split out the SLURM specific PR from the SFT PR?

Contributor (Author): Hmm that's a reasonable approach. I'll think of how to separate them and raise a new one

@@ -94,9 +96,19 @@ def _init_dist(self):
be explicit for now.

"""
# Calculate local rank - rank within the node
# For multi-node setups, LOCAL_RANK should be rank % gpus_per_node
size_info = current_size()

# size_info = {'hosts': 8, 'procs': 4} for 8 nodes with 4 GPUs each
local_world_size = (
size_info.get("procs", self._size) if size_info else self._size
)
local_rank = self._rank % local_world_size
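# Worked example (hypothetical, matching the size_info example above): with 8 hosts x 4 procs,
# local_world_size = 4, so global rank 13 gets local_rank = 13 % 4 = 1.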

env = {
"RANK": str(self._rank),
"LOCAL_RANK": str(self._rank),
"LOCAL_RANK": str(local_rank),
"LOCAL_WORLD_SIZE": str(self._size),
"GROUP_RANK": str(self._size),
"GROUP_WORLD_SIZE": str(self._size),
@@ -105,12 +117,15 @@
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self._size),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
# Add other environment variables as needed - NCCL related variables, etc
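# (examples only, cluster-dependent) e.g. NCCL_DEBUG or NCCL_SOCKET_IFNAME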
}
os.environ.update(env)
logger.info("env: {}".format(env))

async def setup_metric_logger(self):
"""Initialization happens in the main process. Here we just retrieve it"""
"""Retrieve the already-initialized metric logger from main process"""
# Don't create new logger - it was already initialized in main process
# Just retrieve the existing one
mlogger = await get_or_create_metric_logger()
return mlogger

@@ -123,8 +138,8 @@ def record_batch_metrics(self, data_metrics: list):
@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()
self.mlogger = await self.setup_metric_logger()

self.mlogger = await self.setup_metric_logger()
# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
@@ -138,11 +153,16 @@ async def setup(self):

# TODO: confirm that this is working properly
# Should also use load, not dcp_load

# Setup training data (first 90% of train split)

# Load checkpoint if resuming
self.checkpointer.load(step=self.current_step)
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
"""Setup data with configurable dataset path and split."""
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
@@ -165,11 +185,26 @@ def setup_data(self):
),
)

# Ultimately we probably want something like this
# packer = build_packing_strategy(packing_config)
# dataset = build_dataset(dataset_config)
# dataloader = build_dataloader(dataloader_config, dataset, packer)

# Get data config from YAML (num_shards_per_rank, num_dataloader_workers)
data_config = getattr(self.job_config, "data", None)
num_shards_per_rank = (
getattr(data_config, "num_shards_per_rank", 64) if data_config else 64
)
num_dataloader_workers = (
getattr(data_config, "num_dataloader_workers", 0) if data_config else 0
)

dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
path="yahma/alpaca-cleaned",
split="train",
num_shards_per_rank=num_shards_per_rank,
)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -180,15 +215,12 @@
dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=self.job_config.training.local_batch_size,
num_workers=num_dataloader_workers,
collate_fn=partial(
collate_packed, mask_fn=packer.create_block_mask, device=self.device
),
)

# Ultimately we probably want something like this
# packer = build_packing_strategy(packing_config)
# dataset = build_dataset(dataset_config)
# dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader

def forward_backward(
@@ -228,7 +260,6 @@ def forward_backward(
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
@@ -258,10 +289,12 @@ def train_step(self, batch) -> None:
loss = self.forward_backward(batch, labels)
loss = loss.item()

# Record loss metric
record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)

self.optimizers.step()
self.lr_schedulers.step()

@@ -283,22 +316,22 @@ async def train(self) -> None:
# Move tensors to the appropriate device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to("cuda") # TODO: hardcoded for now
batch[k] = v.to(self.device)

self.train_step(batch)
# self.profiler.step()
self.current_step += 1

# Flush metrics
if self._rank == 0:
if self._rank == 0 and self.mlogger is not None:
logger.debug(f"Flushing metrics at step {self.current_step}")
await self.mlogger.flush.call_one(global_step=self.current_step)

# Save checkpoints
self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# self.pbar.close()

@endpoint
@@ -313,28 +346,54 @@ def __repr__(self) -> str:


async def run(cfg: DictConfig) -> None:
"""Main SFT training loop with provisioner support for multi-node training."""
# ---- Global setups ---- #
provisioner = None
if cfg.get("provisioner", None) is not None:
logging.info("Initializing provisioner with launcher configuration...")
provisioner = await init_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
else:
logging.info("Initializing default provisioner...")
provisioner = await init_provisioner()

logging.info("Spawning recipe...")
process_cfg = cfg.pop("processes")

# Initialize metric logger in main process
# ---- Initialize metric logger in main process ---- #
metric_logging_cfg = cfg.get("metric_logging", {})
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)

recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
# ---- Setup SFT Recipe Actor ---- #
logging.info("Spawning recipe...")
actor_cfg = cfg.pop("actors", None)

if actor_cfg is None:
# Fallback to old "processes" config for backward compatibility
actor_cfg = cfg.pop("processes", {"procs": 8, "with_gpus": True})
logging.warning(
"Using legacy 'processes' config. Consider migrating to 'actors' config."
)

recipe_options = actor_cfg.get("trainer", actor_cfg)
recipe = await ForgeSFTRecipe.options(**recipe_options).as_actor(cfg)

logging.info("Created recipe, running setup.")
await recipe.setup.call()

logging.info("Recipe has been setup. Training now.")
await recipe.train.call()

logging.info("Done training. Clean up")
await recipe.cleanup.call()

await recipe.mesh.stop()
logging.info("All done!")
try:
await recipe.train.call()
except KeyboardInterrupt:
logging.info("Training interrupted by user")
finally:
logging.info("Done training. Clean up")
await recipe.cleanup.call()
await ForgeSFTRecipe.shutdown(recipe)

# Shutdown provisioner
await shutdown()
logging.info("All done!")


@parse
90 changes: 90 additions & 0 deletions apps/sft/qwen3_32b.yaml
@@ -0,0 +1,90 @@
# Multi-Node SFT Configuration for Qwen3-32B
# >>> python -m apps.sft.main --config apps/sft/qwen3_32b.yaml

comm:
trace_buf_size: 0

model_name: "Qwen/Qwen3-32B"

provisioner:
launcher: slurm
cpu: # CPUs per node - if empty, will be inferred from Slurm
memory_mb: # Memory in MB per node - if empty, will be inferred from Slurm
gpus_per_node: # Number of GPUs per node - if empty, will be inferred from Slurm
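# Example with explicit values instead of Slurm inference (hypothetical numbers; adjust to the target cluster):
# cpu: 96
# memory_mb: 1024000
# gpus_per_node: 4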

# Actor configuration for multi-node training
actors:
trainer:
procs: 4 # Number of GPU processes per node
hosts: 64 # Number of nodes to use
with_gpus: true
mesh_name: trainer
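# A minimal sketch of the legacy "processes" schema that apps/sft/main.py still accepts as a fallback
# (the "actors" block above is the preferred form):
# processes:
#   procs: 8
#   with_gpus: true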

model:
name: qwen3
flavor: 32B
hf_assets_path: hf://${model}

optimizer:
name: AdamW
lr: 1e-5
eps: 1e-8

lr_scheduler:
warmup_steps: 200

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000000
compile: false
dataset: "c4"


data:
# This needs to be adjusted based on the dataset size and world size: sample count >= world size * num_shards_per_rank
num_shards_per_rank: 64 # Default: 64
num_dataloader_workers: 0 # 0 = no worker processes
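# Worked example for the actors block above: 64 hosts x 4 procs = 256 ranks, so the dataset
# needs at least 256 * 64 = 16384 samples when num_shards_per_rank is 64.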


parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: false

checkpoint:
enable: true
folder: ./checkpoints
# Fine-tune from the pre-trained HF model (base model) with these settings:
initial_load_path: hf://${model}
initial_load_in_hf: true
last_save_in_hf: true
interval: 500
async_mode: "disabled" # Save checkpoints in background without blocking training

activation_checkpoint:
mode: full

# Metric logging configuration
metric_logging:
wandb:
project: sft-training
group: sft_exp_${oc.env:USER}
logging_mode: global_reduce # options: global_reduce, per_rank_reduce, per_rank_no_reduce
# console:
# reduce_across_ranks: True

# Optional: Profiling configuration
# profiling:
# enable_profiling: false

# Optional: Metrics configuration
# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"