Add Multi-Node Distributed Training Support for SLURM Clusters #528

base: main · Changes from 2 commits
apps/sft/main.py
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -23,11 +22,13 @@

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse

from monarch.actor import current_rank, current_size, endpoint
@@ -41,8 +42,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -78,10 +77,13 @@ def __init__(self, config: DictConfig):

        self.current_step = 0
        self.num_training_steps = job_config.training.steps
        self.metric_logger = None  # TODO: fix this
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())

        self._init_dist()

        super().__init__(job_config)

    def _init_dist(self):
Contributor: if we remove the _init_dist altogether would this still work? I added this line in
Contributor (Author): I need to test this. Right now, I pass the local rank and NCCL variables within the env there. Will keep you posted.
Contributor: Maybe we should split out the SLURM-specific PR from the SFT PR?
Contributor (Author): Hmm, that's a reasonable approach. I'll think of how to separate them and raise a new one.
@@ -94,9 +96,19 @@ def _init_dist(self):
        be explicit for now.

        """
        # Calculate local rank - rank within the node
        # For multi-node setups, LOCAL_RANK should be rank % gpus_per_node
        size_info = current_size()

        # size_info = {'hosts': 8, 'procs': 4} for 8 nodes with 4 GPUs each
        local_world_size = (
            size_info.get("procs", self._size) if size_info else self._size
        )
        local_rank = self._rank % local_world_size

        env = {
            "RANK": str(self._rank),
            "LOCAL_RANK": str(self._rank),
            "LOCAL_RANK": str(local_rank),
            "LOCAL_WORLD_SIZE": str(self._size),
            "GROUP_RANK": str(self._size),
            "GROUP_WORLD_SIZE": str(self._size),
@@ -105,12 +117,15 @@ def _init_dist(self):
            "ROLE_NAME": "rank",
            "WORLD_SIZE": str(self._size),
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            # Add other environment variables as needed - NCCL related variables, etc
        }
        os.environ.update(env)
        logger.info("env: {}".format(env))

    async def setup_metric_logger(self):
        """Initialization happens in the main process. Here we just retrieve it"""
        """Retrieve the already-initialized metric logger from main process"""
        # Don't create new logger - it was already initialized in main process
        # Just retrieve the existing one
        mlogger = await get_or_create_metric_logger()
        return mlogger
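To make the rank arithmetic above concrete, here is a minimal, standalone sketch using the 8-node x 4-GPU example from the inline comment; none of the names below come from the PR itself.

```python
# Standalone sketch of the LOCAL_RANK derivation in _init_dist, assuming
# current_size() reports {'hosts': 8, 'procs': 4} and ranks are laid out
# contiguously per host (the same assumption the modulo above relies on).
import math

size_info = {"hosts": 8, "procs": 4}
world_size = math.prod(size_info.values())   # 32 global ranks
local_world_size = size_info["procs"]        # processes (GPUs) per node

for rank in (0, 5, 31):
    local_rank = rank % local_world_size     # rank within the node: 0..3
    host_index = rank // local_world_size    # which node the rank lives on: 0..7
    print(rank, host_index, local_rank)      # e.g. 5 -> host 1, local rank 1
```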
@@ -123,8 +138,8 @@ def record_batch_metrics(self, data_metrics: list):
    @endpoint
    async def setup(self):
        self.train_dataloader = self.setup_data()
        self.mlogger = await self.setup_metric_logger()

        self.mlogger = await self.setup_metric_logger()
        # self.train_dataloader = self.setup_data(
        #     self.train_config.train_dataset_config,
        #     self.train_config.train_dataloader_config,
@@ -138,11 +153,16 @@ async def setup(self):

        # TODO: confirm that this is working properly
        # Should also use load, not dcp_load

        # Setup training data (first 90% of train split)

        # Load checkpoint if resuming
        self.checkpointer.load(step=self.current_step)
        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
        # self.logger = self.setup_logger(self.train_config.logger_config)

    def setup_data(self):
        """Setup data with configurable dataset path and split."""
        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
        tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=os.path.join(
@@ -165,11 +185,26 @@ def setup_data(self):
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)

        # Get data config from YAML (num_shards_per_rank, num_dataloader_workers)
        data_config = getattr(self.job_config, "data", None)
        num_shards_per_rank = (
            getattr(data_config, "num_shards_per_rank", 64) if data_config else 64
        )
        num_dataloader_workers = (
            getattr(data_config, "num_dataloader_workers", 0) if data_config else 0
        )

        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=AlpacaToMessages(),
            path="yahma/alpaca-cleaned",
            split="train",
            num_shards_per_rank=num_shards_per_rank,
        )
        packer = TextPacker(padding_idx=0)
        dataset = PackedDataset(
@@ -180,15 +215,12 @@ def setup_data(self):
        dataloader = StatefulDataLoader(
            dataset=dataset,
            batch_size=self.job_config.training.local_batch_size,
            num_workers=num_dataloader_workers,
            collate_fn=partial(
                collate_packed, mask_fn=packer.create_block_mask, device=self.device
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)
        return dataloader

    def forward_backward(
@@ -228,7 +260,6 @@ def forward_backward(
        )

        # accumulate losses across pipeline microbatches
        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
        loss = (
            torch.mean(torch.stack(losses)).to(self.device)
            if self.pp_has_last_stage
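As a side note on the averaging above: only ranks that hold the last pipeline stage have per-microbatch losses to reduce. A tiny, self-contained illustration of the reduction itself (the loss values are made up):

```python
# Illustration of the microbatch loss reduction performed when
# pp_has_last_stage is True. Values are placeholders.
import torch

losses = [torch.tensor(2.10), torch.tensor(1.95), torch.tensor(2.04)]  # one scalar per microbatch
loss = torch.mean(torch.stack(losses))  # single scalar used for logging/metrics
print(loss.item())
```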
@@ -258,10 +289,12 @@ def train_step(self, batch) -> None:
        loss = self.forward_backward(batch, labels)
        loss = loss.item()

        # Record loss metric
        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
        # self.pbar.update(1)

        self.optimizers.step()
        self.lr_schedulers.step()
@@ -283,22 +316,22 @@ async def train(self) -> None:
            # Move tensors to the appropriate device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to("cuda")  # TODO: hardcoded for now
                    batch[k] = v.to(self.device)  # TODO: hardcoded for now

            self.train_step(batch)
            # self.profiler.step()
            self.current_step += 1

            # Flush metrics
            if self._rank == 0:
            if self._rank == 0 and self.mlogger is not None:
                logger.debug(f"Flushing metrics at step {self.current_step}")
                await self.mlogger.flush.call_one(global_step=self.current_step)

            # Save checkpoints
            self.checkpointer.save(
                curr_step=self.current_step,
                last_step=self.current_step == self.num_training_steps,
            )

        # self.pbar.close()

    @endpoint
@@ -313,28 +346,54 @@ def __repr__(self) -> str:


async def run(cfg: DictConfig) -> None:
    """Main SFT training loop with provisioner support for multi-node training."""
    # ---- Global setups ---- #
    provisioner = None
    if cfg.get("provisioner", None) is not None:
        logging.info("Initializing provisioner with launcher configuration...")
        provisioner = await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        logging.info("Initializing default provisioner...")
        provisioner = await init_provisioner()

    logging.info("Spawning recipe...")
    process_cfg = cfg.pop("processes")

    # Initialize metric logger in main process
    # ---- Initialize metric logger in main process ---- #
    metric_logging_cfg = cfg.get("metric_logging", {})
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)

    recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
    # ---- Setup SFT Recipe Actor ---- #
    logging.info("Spawning recipe...")
    actor_cfg = cfg.pop("actors", None)

    if actor_cfg is None:
        # Fallback to old "processes" config for backward compatibility
        actor_cfg = cfg.pop("processes", {"procs": 8, "with_gpus": True})
        logging.warning(
            "Using legacy 'processes' config. Consider migrating to 'actors' config."
        )

    recipe_options = actor_cfg.get("trainer", actor_cfg)
    recipe = await ForgeSFTRecipe.options(**recipe_options).as_actor(cfg)

    logging.info("Created recipe, running setup.")
    await recipe.setup.call()

    logging.info("Recipe has been setup. Training now.")
    await recipe.train.call()

    logging.info("Done training. Clean up")
    await recipe.cleanup.call()

    await recipe.mesh.stop()
    logging.info("All done!")
    try:
        await recipe.train.call()
    except KeyboardInterrupt:
        logging.info("Training interrupted by user")
    finally:
        logging.info("Done training. Clean up")
        await recipe.cleanup.call()
        await ForgeSFTRecipe.shutdown(recipe)

    # Shutdown provisioner
    await shutdown()
    logging.info("All done!")


@parse
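One part of this flow that is easy to miss is the split between controller-side and actor-side metric logging. Below is a condensed sketch of the calls involved, using only APIs that already appear in this diff; the wrapper function names are illustrative and assume the forge runtime is available.

```python
# Sketch of the metric-logging flow in this PR: the controller initializes the
# backends once, each ForgeSFTRecipe actor retrieves the same logger, records
# metrics during training, and only rank 0 flushes per step.
from forge.observability import get_or_create_metric_logger, record_metric, Reduce


async def controller_side(metric_logging_cfg: dict) -> None:
    # Runs in the main process, before the recipe actors are spawned.
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)


async def actor_side(loss: float, step: int, rank: int) -> None:
    # Runs inside each recipe actor during train_step / train.
    record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
    mlogger = await get_or_create_metric_logger()  # retrieve, do not re-init
    if rank == 0 and mlogger is not None:
        await mlogger.flush.call_one(global_step=step)
```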
apps/sft/qwen3_32b_multinode.yaml
@@ -0,0 +1,90 @@
# Multi-Node SFT Configuration for Qwen3-32B
# >>> python -m apps.sft.main --config apps/sft/qwen3_32b_multinode.yaml

comm:
  trace_buf_size: 0

model_name: "Qwen/Qwen3-32B"

provisioner:
  launcher: slurm
  cpu: # CPUs per node - if empty, will be inferred from Slurm
  memory_mb: # Memory in MB per node - if empty, will be inferred from Slurm
  gpus_per_node: # Number of GPUs per node - if empty, will be inferred from Slurm

# Actor configuration for multi-node training
actors:
  trainer:
    procs: 4 # Number of GPU processes per node
    hosts: 64 # Number of nodes to use
    with_gpus: true
    mesh_name: trainer
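For reference, the trainer mesh implied by this block is 64 hosts x 4 processes per host, which is exactly what the recipe computes as `math.prod(current_size().values())`. A quick check (illustrative only):

```python
# Mesh size implied by actors.trainer: 64 hosts x 4 procs per host.
import math

size_info = {"hosts": 64, "procs": 4}        # mirrors hosts/procs above
world_size = math.prod(size_info.values())   # what the recipe stores as self._size
print(world_size)                            # 256 GPUs / global ranks
```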
model:
  name: qwen3
  flavor: 32B
  hf_assets_path: hf://${model}

optimizer:
  name: AdamW
  lr: 1e-5
  eps: 1e-8

lr_scheduler:
  warmup_steps: 200

training:
  local_batch_size: 1
  seq_len: 2048
  max_norm: 1.0
  steps: 1000000
  compile: false
  dataset: "c4"

data:
  # Adjust based on dataset size and world size - sample count >= world size * num_shards_per_rank
  num_shards_per_rank: 64 # Default: 64
  num_dataloader_workers: 0 # 0 = no worker processes
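To make the comment on `num_shards_per_rank` concrete, here is the arithmetic for this config's 256-rank layout; the dataset size below is a placeholder, not a measured number.

```python
# Sanity check for the rule stated above:
#   dataset samples >= world_size * num_shards_per_rank
world_size = 64 * 4                          # hosts * procs from the actors block
num_shards_per_rank = 64                     # data.num_shards_per_rank
required_samples = world_size * num_shards_per_rank
print(required_samples)                      # 16384

dataset_samples = 100_000                    # placeholder value for illustration
print(dataset_samples >= required_samples)   # True -> this shard setting is feasible
```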
parallelism:
  data_parallel_replicate_degree: 1
  data_parallel_shard_degree: -1
  tensor_parallel_degree: 1
  pipeline_parallel_degree: 1
  context_parallel_degree: 1
  expert_parallel_degree: 1
  disable_loss_parallel: false
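Assuming torchtitan's usual convention that a degree of -1 is inferred from whatever world size remains after the explicit degrees are applied, the FSDP shard degree here would resolve to the full 256-GPU mesh. A rough sketch of that assumption:

```python
# Rough sketch of how data_parallel_shard_degree: -1 could resolve, assuming
# -1 means "use the world size left over by the other parallelism degrees".
world_size = 256                             # 64 hosts * 4 procs
explicit_degrees = {
    "data_parallel_replicate_degree": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "context_parallel_degree": 1,
}

remaining = world_size
for degree in explicit_degrees.values():
    remaining //= degree

data_parallel_shard_degree = remaining       # 256: pure FSDP sharding across all GPUs
print(data_parallel_shard_degree)
```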
checkpoint:
  enable: true
  folder: ./checkpoints
  # To fine-tune from pre-trained HF model (base model), uncomment these:
  initial_load_path: hf://${model}
  initial_load_in_hf: true
  last_save_in_hf: true
  interval: 500
  async_mode: "disabled" # "disabled" saves checkpoints synchronously; enable an async mode to save in the background

activation_checkpoint:
  mode: full

# Metric logging configuration
metric_logging:
  wandb:
    project: sft-training
    group: sft_exp_${oc.env:USER}
    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
  # console:
  #   reduce_across_ranks: True

# Optional: Profiling configuration
# profiling:
#   enable_profiling: false

# Optional: Metrics configuration
# metrics:
#   log_freq: 10
#   enable_tensorboard: true
#   save_tb_folder: "tb"