 from omegaconf import DictConfig, ListConfig

 from torch import nn
-from torch.distributed import destroy_process_group, init_process_group
-
+from torch.distributed import (
+    destroy_process_group,
+    init_device_mesh,
+    init_process_group,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
@@ -136,14 +141,26 @@ def __init__(self, cfg: DictConfig) -> None:
             or self._enable_async_checkpointing,
         )
         init_process_group(self.distributed_backend)
-        _, rank = utils.get_world_size_and_rank()
-        self._is_rank_zero = rank == 0
+
+        # Initialize distributed variables
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
+        self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
+        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
+        if self.tensor_parallel_dim > 1 and self.parallelize_plan is None:
+            raise ValueError(
+                "A parallelism plan needs to be provided when tensor parallelism is enabled."
+            )
+        if self.world_size % self.tensor_parallel_dim != 0:
+            raise ValueError(
+                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+            )
+        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim

         # Logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
-
         if self._log_peak_memory_stats and device_type != "cuda":
             log.info(
                 "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
@@ -505,7 +522,7 @@ def _setup_model(

         utils.log_rank_zero(
             log,
-            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
         )
         init_start = time.perf_counter()

@@ -515,6 +532,24 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

+        device_mesh = init_device_mesh(
+            self._device.type,
+            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
+            mesh_dim_names=("dp", "tp"),
+        )
+        self.dp_size = device_mesh["dp"].size()
+        self.dp_rank = device_mesh["dp"].get_local_rank()
+
+        # Apply tensor parallelism to the model
+        if self.tensor_parallel_dim > 1:
+            # Use the local numbers (num_heads, num_kv_heads, embed_dim) to account for tensor parallelism
+            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=self.parallelize_plan,
+            )
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
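For readers unfamiliar with the new calls: `init_device_mesh` builds a 2D ("dp", "tp") mesh over the launched ranks, and `parallelize_module` applies a plan mapping module names to `ParallelStyle`s along the "tp" submesh. Below is a minimal sketch on a toy MLP, assuming a `torchrun` launch with `dp * tp` processes; the plan is illustrative, not the recipe's `parallelize_plan`:

```python
import torch
from torch import nn
from torch.distributed import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical shape: 4 ranks split into 2 data-parallel x 2 tensor-parallel.
dp, tp = 2, 2
mesh = init_device_mesh("cuda", mesh_shape=(dp, tp), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()
# Shard the first Linear column-wise and the last row-wise along the "tp" dim;
# a real parallelize_plan would describe the analogous layout for attention/MLP weights.
parallelize_module(
    model,
    mesh["tp"],
    parallelize_plan={"0": ColwiseParallel(), "2": RowwiseParallel()},
)
```

Per the in-code comment, `training.prepare_mha_for_tp` is the torchtune-specific step that switches attention hyperparameters (num_heads, num_kv_heads, embed_dim) to their per-shard values before the plan is applied.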
@@ -534,19 +569,21 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        # For FSDP sharding
-        fsdp_shard_conditions = [
-            partial(
-                training.get_shard_conditions,
-                names_to_match=custom_sharded_layers,
+        # Apply Fully Sharded Data Parallelism to the model
+        if self.data_parallel_dim > 1:
+            fsdp_shard_conditions = [
+                partial(
+                    training.get_shard_conditions,
+                    names_to_match=custom_sharded_layers,
+                )
+            ]
+            training.shard_model(
+                model=model,
+                shard_conditions=fsdp_shard_conditions,
+                cpu_offload=fsdp_cpu_offload,
+                reshard_after_forward=reshard_after_forward,
+                dp_mesh=device_mesh["dp"],
             )
-        ]
-        training.shard_model(
-            model=model,
-            shard_conditions=fsdp_shard_conditions,
-            cpu_offload=fsdp_cpu_offload,
-            reshard_after_forward=reshard_after_forward,
-        )

         with training.set_default_dtype(self._dtype), self._device:
             for m in model.modules():
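With tensor parallelism applied first, data-parallel sharding is now restricted to the "dp" submesh, and skipped entirely when `data_parallel_dim == 1`. `training.shard_model` wraps FSDP2; a rough sketch of the underlying idea, continuing the toy `model` and `mesh` from the previous sketch and assuming a recent PyTorch where FSDP2's `fully_shard` is importable:

```python
# Assumption: fully_shard lives under torch.distributed.fsdp in newer releases
# (torch.distributed._composable.fsdp in older ones).
from torch.distributed.fsdp import fully_shard

dp_mesh = mesh["dp"]
# Shard parameter-bearing children first (analogous to fsdp_shard_conditions
# selecting transformer layers), then the root module, only along the "dp"
# dimension so the tensor-parallel placements are preserved.
for layer in model:
    if isinstance(layer, nn.Linear):
        fully_shard(layer, mesh=dp_mesh, reshard_after_forward=True)
fully_shard(model, mesh=dp_mesh, reshard_after_forward=True)
```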
@@ -651,8 +688,6 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = utils.get_world_size_and_rank()
-
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
                 config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -670,7 +705,7 @@ def _setup_data(
         collate_fn = _get_component_from_path(collate_fn)

         sampler = DistributedSampler(
-            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
         )
         dataloader = DataLoader(
             dataset=ds,
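The sampler now partitions data across data-parallel replicas only: all tensor-parallel ranks within a replica must see identical batches, so `num_replicas` and `rank` come from the "dp" mesh rather than the global world size and rank. A small standalone illustration with hypothetical numbers:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))

# Hypothetical 8-rank job with tensor_parallel_dim=2: dp_size is 4, and the two
# tp ranks inside a replica share the same dp_rank, hence the same indices.
dp_size, dp_rank = 4, 1
sampler = DistributedSampler(
    dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=0
)
sampler.set_epoch(0)
indices = list(iter(sampler))  # 25 indices, identical on both tp ranks of replica 1
```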
@@ -700,8 +735,6 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()

-        world_size, rank = utils.get_world_size_and_rank()
-
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:
             self._optimizer.zero_grad()
@@ -721,7 +754,7 @@ def train(self) -> None:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)

-            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -739,7 +772,6 @@ def train(self) -> None:
                     and self._device.type == "cuda"
                 ):
                     torch.cuda.memory._record_memory_history()
-
                 utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
@@ -782,7 +814,7 @@ def train(self) -> None:
                     torch.distributed.all_reduce(running_loss)

                     # We multiply by world_size to undo FSDP2 gradient normalization.
-                    current_loss = current_loss * (world_size / num_tokens)
+                    current_loss = current_loss * (self.world_size / num_tokens)

                 current_loss.backward()

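A note on the `self.world_size / num_tokens` factor (unchanged in spirit, just re-scoped to the stored attribute): per the in-code comment, FSDP2's gradient reduction averages over all ranks, i.e. divides by the world size, while the loss here is an unnormalized per-rank sum over tokens. A worked example with hypothetical numbers:

```python
# Hypothetical values for one optimizer step.
world_size = 8
per_rank_token_loss_sum = 512.0   # this rank's summed token losses
num_tokens = 4096                 # global unmasked-token count after all_reduce

scaled_loss = per_rank_token_loss_sum * (world_size / num_tokens)
# FSDP's reduction later divides gradients by world_size, so the net effect is
# per_rank_token_loss_sum / num_tokens, i.e. a mean over all tokens globally.
```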
@@ -795,12 +827,15 @@ def train(self) -> None:
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
                         # We multiply by world_size to undo FSDP2 gradient normalization.
-                        training.scale_grads(self._model, world_size / num_tokens)
+                        training.scale_grads(self._model, self.world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
                                 max_norm=float(self._clip_grad_norm),
-                            ).full_tensor()
+                            )
+                            # If sharded, collect the DTensor here
+                            if isinstance(grad_norm, DTensor):
+                                grad_norm = grad_norm.full_tensor()
                         self._optimizer.step()
                         self._optimizer.zero_grad(set_to_none=True)

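Why the new `isinstance` check: when parameters are sharded (FSDP2 and/or the tensor-parallel plan), `clip_grad_norm_` returns the norm as a `DTensor`, and `.full_tensor()` gathers it into a regular tensor for logging; when nothing is sharded (e.g. a single-process debug run), the norm is already a plain tensor and the old unconditional `.full_tensor()` call would not apply. A tiny helper sketch of the same logic (hypothetical name, not part of the recipe):

```python
import torch
from torch.distributed._tensor import DTensor

def materialize_grad_norm(grad_norm: torch.Tensor) -> torch.Tensor:
    # Gather the full value if clip_grad_norm_ returned a sharded DTensor,
    # otherwise pass the plain tensor through unchanged.
    if isinstance(grad_norm, DTensor):
        grad_norm = grad_norm.full_tensor()
    return grad_norm
```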
@@ -833,7 +868,7 @@ def train(self) -> None:
                                 ),
                             ),
                             "tokens_per_second_per_gpu": num_tokens
-                            / (time_per_step * world_size),
+                            / (time_per_step * self.world_size),
                         }
                         if self._log_peak_memory_stats:
                             log_dict.update(