
Commit 7754294

pytorchbot committed: 2025-02-10 nightly release (9da35c7)
1 parent: 08f9786

File tree (7 files changed, +164 -40 lines):

  recipes/configs/llama3/70B_full.yaml
  recipes/configs/llama3_1/70B_full.yaml
  recipes/configs/llama3_3/70B_full.yaml
  recipes/dev/generate_v2_distributed.py
  recipes/full_finetune_distributed.py
  tests/recipes/test_full_finetune_distributed.py
  torchtune/training/_distributed.py

recipes/configs/llama3/70B_full.yaml

Lines changed: 6 additions & 1 deletion
@@ -19,6 +19,11 @@
 
 output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference.
 
+# Parallelism
+tensor_parallel_dim: 1
+parallelize_plan:
+  _component_: torchtune.models.llama3.base_llama_tp_plan
+
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
@@ -54,7 +59,7 @@ epochs: 1
 optimizer:
   _component_: torch.optim.AdamW
   lr: 2e-5
-  fused: True
+  fused: False
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
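
Note: the same two keys are added to the llama3_1 and llama3_3 70B configs below. tensor_parallel_dim defaults to 1 (tensor parallelism off), and parallelize_plan points at the base Llama 3 tensor-parallel plan, so 2D parallelism can be enabled from the config alone. A minimal sketch of how the recipe consumes these fields, assuming OmegaConf and torchtune's config.instantiate (illustrative, not the recipe's exact code):

from omegaconf import OmegaConf
from torchtune import config

# Hypothetical user override, e.g. `tensor_parallel_dim=2` passed on the tune run command line
cfg = OmegaConf.create(
    {
        "tensor_parallel_dim": 2,
        "parallelize_plan": {
            "_component_": "torchtune.models.llama3.base_llama_tp_plan"
        },
    }
)

tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
# instantiate() resolves the _component_ dotted path into the actual TP plan
parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
if tensor_parallel_dim > 1 and parallelize_plan is None:
    raise ValueError("A parallelize_plan is required when tensor parallelism is enabled.")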

recipes/configs/llama3_1/70B_full.yaml

Lines changed: 6 additions & 1 deletion
@@ -18,6 +18,11 @@
 
 output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference.
 
+# Parallelism
+tensor_parallel_dim: 1
+parallelize_plan:
+  _component_: torchtune.models.llama3.base_llama_tp_plan
+
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
   lr: 2e-5
   # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
-  fused: True
+  fused: False
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

recipes/configs/llama3_3/70B_full.yaml

Lines changed: 6 additions & 1 deletion
@@ -18,6 +18,11 @@
 
 output_dir: /tmp/torchtune/llama3_3_70B/full # /tmp may be deleted by your system. Change it to your preference.
 
+# Parallelism
+tensor_parallel_dim: 1
+parallelize_plan:
+  _component_: torchtune.models.llama3.base_llama_tp_plan
+
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
@@ -55,7 +60,7 @@ optimizer:
   lr: 2e-5
   # Note: highly recommended to use fused=True optimizer flag
   # with CPU offload for faster optimizer step.
-  fused: True
+  fused: False
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

recipes/dev/generate_v2_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def setup(self, cfg: DictConfig) -> None:
         tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)
 
         # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell
-        training.prepare_mha_for_tp(model, tp_device_mesh)
+        model = training.prepare_mha_for_tp(model, tp_device_mesh)
         parallelize_module(
             model,
             tp_device_mesh,
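
The fix rebinds model to the value returned by prepare_mha_for_tp rather than discarding it. For orientation, a hedged sketch of the setup path this hunk sits in, assuming a pure tensor-parallel mesh over all launched ranks, an already-initialized process group, and a placeholder tp_plan argument (illustrative, not the recipe's exact code):

import torch.distributed as dist
from torch import nn
from torch.distributed.tensor.parallel import parallelize_module
from torchtune import training


def setup_tensor_parallel(model: nn.Module, tp_plan) -> nn.Module:
    # One-dimensional tensor-parallel mesh spanning every launched rank.
    tp_mesh_shape = (dist.get_world_size(),)
    tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)

    # prepare_mha_for_tp hands back the model whose attention hyperparameters
    # (num_heads, num_kv_heads, embed_dim) are now per-shard values, so keep
    # the returned object instead of relying on in-place mutation.
    model = training.prepare_mha_for_tp(model, tp_device_mesh)
    parallelize_module(model, tp_device_mesh, parallelize_plan=tp_plan)
    return model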

recipes/full_finetune_distributed.py

Lines changed: 64 additions & 29 deletions
@@ -15,8 +15,13 @@
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
-from torch.distributed import destroy_process_group, init_process_group
-
+from torch.distributed import (
+    destroy_process_group,
+    init_device_mesh,
+    init_process_group,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
@@ -136,14 +141,26 @@ def __init__(self, cfg: DictConfig) -> None:
             or self._enable_async_checkpointing,
         )
         init_process_group(self.distributed_backend)
-        _, rank = utils.get_world_size_and_rank()
-        self._is_rank_zero = rank == 0
+
+        # Initialize distributed variables
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
+        self.parallelize_plan = config.instantiate(cfg.get("parallelize_plan", None))
+        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
+        if self.tensor_parallel_dim > 1 and self.parallelize_plan is None:
+            raise ValueError(
+                "Parallelism plan need to be provided when tensor parallel is enabled."
+            )
+        if self.world_size % self.tensor_parallel_dim != 0:
+            raise ValueError(
+                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+            )
+        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
 
         # Logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
-
         if self._log_peak_memory_stats and device_type != "cuda":
             log.info(
                 "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
@@ -505,7 +522,7 @@ def _setup_model(
 
         utils.log_rank_zero(
             log,
-            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
         )
         init_start = time.perf_counter()
 
@@ -515,6 +532,24 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
 
+        device_mesh = init_device_mesh(
+            self._device.type,
+            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
+            mesh_dim_names=("dp", "tp"),
+        )
+        self.dp_size = device_mesh["dp"].size()
+        self.dp_rank = device_mesh["dp"].get_local_rank()
+
+        # Apply tensor parallelism to the model
+        if self.tensor_parallel_dim > 1:
+            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
+            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=self.parallelize_plan,
+            )
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -534,19 +569,21 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        # For FSDP sharding
-        fsdp_shard_conditions = [
-            partial(
-                training.get_shard_conditions,
-                names_to_match=custom_sharded_layers,
+        # Apply Fully Sharded Data Parallelism to the model
+        if self.data_parallel_dim > 1:
+            fsdp_shard_conditions = [
+                partial(
+                    training.get_shard_conditions,
+                    names_to_match=custom_sharded_layers,
+                )
+            ]
+            training.shard_model(
+                model=model,
+                shard_conditions=fsdp_shard_conditions,
+                cpu_offload=fsdp_cpu_offload,
+                reshard_after_forward=reshard_after_forward,
+                dp_mesh=device_mesh["dp"],
             )
-        ]
-        training.shard_model(
-            model=model,
-            shard_conditions=fsdp_shard_conditions,
-            cpu_offload=fsdp_cpu_offload,
-            reshard_after_forward=reshard_after_forward,
-        )
 
         with training.set_default_dtype(self._dtype), self._device:
             for m in model.modules():
@@ -651,8 +688,6 @@ def _setup_data(
             DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
             iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = utils.get_world_size_and_rank()
-
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
                 config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -670,7 +705,7 @@ def _setup_data(
             collate_fn = _get_component_from_path(collate_fn)
 
         sampler = DistributedSampler(
-            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
         )
         dataloader = DataLoader(
             dataset=ds,
@@ -700,8 +735,6 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = utils.get_world_size_and_rank()
-
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:
             self._optimizer.zero_grad()
@@ -721,7 +754,7 @@ def train(self) -> None:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
 
-            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -739,7 +772,6 @@ def train(self) -> None:
                     and self._device.type == "cuda"
                 ):
                     torch.cuda.memory._record_memory_history()
-
                 utils.batch_to_device(batch, self._device)
 
                 # Calculate the number of unmasked tokens in the current batch
@@ -782,7 +814,7 @@ def train(self) -> None:
                     torch.distributed.all_reduce(running_loss)
 
                     # We multiply by world_size to undo FSDP2 gradient normalization.
-                    current_loss = current_loss * (world_size / num_tokens)
+                    current_loss = current_loss * (self.world_size / num_tokens)
 
                 current_loss.backward()
 
@@ -795,12 +827,15 @@ def train(self) -> None:
                        torch.distributed.all_reduce(running_loss)
                        # Manually scale the gradients from unnormalized loss by total # of tokens
                        # We multiply by world_size to undo FSDP2 gradient normalization.
-                        training.scale_grads(self._model, world_size / num_tokens)
+                        training.scale_grads(self._model, self.world_size / num_tokens)
                        if self._clip_grad_norm is not None:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                self._model.parameters(),
                                max_norm=float(self._clip_grad_norm),
-                            ).full_tensor()
+                            )
+                            # If sharded, collect the DTensor here
+                            if isinstance(grad_norm, DTensor):
+                                grad_norm = grad_norm.full_tensor()
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)
 
@@ -833,7 +868,7 @@ def train(self) -> None:
                            ),
                        ),
                        "tokens_per_second_per_gpu": num_tokens
-                        / (time_per_step * world_size),
+                        / (time_per_step * self.world_size),
                    }
                    if self._log_peak_memory_stats:
                        log_dict.update(
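
Taken together, the recipe now builds a 2D ("dp", "tp") device mesh, applies the tensor-parallel plan on the "tp" submesh, and wraps the model in FSDP only when there is more than one data-parallel replica, passing the "dp" submesh so both forms of parallelism compose. A condensed sketch of that flow, assuming the process group is already initialized and no custom sharded layers (illustrative, not the recipe's exact code):

from functools import partial

from torch import nn
from torch.distributed import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchtune import training


def apply_2d_parallelism(
    model: nn.Module,
    world_size: int,
    tensor_parallel_dim: int,
    parallelize_plan,
    device_type: str = "cuda",
) -> nn.Module:
    data_parallel_dim = world_size // tensor_parallel_dim
    device_mesh = init_device_mesh(
        device_type,
        mesh_shape=(data_parallel_dim, tensor_parallel_dim),
        mesh_dim_names=("dp", "tp"),
    )

    # Tensor parallelism first: switch attention modules to per-shard shapes,
    # then apply the parallelize plan on the "tp" submesh.
    if tensor_parallel_dim > 1:
        model = training.prepare_mha_for_tp(model, device_mesh["tp"])
        parallelize_module(model, device_mesh["tp"], parallelize_plan=parallelize_plan)

    # FSDP only when there is more than one data-parallel replica, sharding
    # over the "dp" submesh so it composes with the TP dimension.
    if data_parallel_dim > 1:
        shard_conditions = [partial(training.get_shard_conditions, names_to_match=None)]
        training.shard_model(
            model=model,
            shard_conditions=shard_conditions,
            cpu_offload=False,
            reshard_after_forward=True,
            dp_mesh=device_mesh["dp"],
        )
    return model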

tests/recipes/test_full_finetune_distributed.py

Lines changed: 69 additions & 0 deletions
@@ -129,6 +129,75 @@ def test_loss(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
 
+    @pytest.mark.integration_test
+    @pytest.mark.parametrize(
+        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd, tensor_parallel_dim",
+        [
+            ("llama3/8B_full", "llama3", "tune", 4, 1, True, 2),
+            ("llama3/8B_full", "llama3", "tune", 4, 1, True, 4),
+        ],
+    )
+    @gpu_test(gpu_count=4)
+    def test_loss_2d_parallel(
+        self,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        config,
+        model_type,
+        ckpt_type,
+        optim_in_bwd,
+        tensor_parallel_dim,
+        tmpdir,
+        monkeypatch,
+    ):
+        ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+        ckpt = model_type + "_" + ckpt_type
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+        parallelize_plan = "torchtune.models.llama3.base_llama_tp_plan"
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+
+        cmd = f"""
+        tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            tensor_parallel_dim={tensor_parallel_dim} \
+            parallelize_plan._component_={parallelize_plan} \
+            metric_logger.filename={log_file} \
+        """.split()
+        model_config = MODEL_TEST_CONFIGS[model_type]
+        cmd = cmd + self._get_test_config_overrides() + model_config
+        # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
+        # wrong grad_norm, so we only test one of them each time. But loss values
+        # should be the same.
+        if not optim_in_bwd:
+            cmd.append("clip_grad_norm=100")
+            # Test that gradient clipping works with CPU offload
+            cmd.append("fsdp_cpu_offload=True")
+        else:
+            cmd.append("optimizer_in_bwd=True")
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )
+
     @pytest.mark.integration_test
     @pytest.mark.parametrize(
         "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",

torchtune/training/_distributed.py

Lines changed: 12 additions & 7 deletions
@@ -31,8 +31,7 @@
 from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
 from torchtune.modules import TransformerDecoder
 from torchtune.modules.attention import MultiHeadAttention
-from torchtune.modules.model_fusion import DeepFusionModel
-
+from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel
 from torchtune.modules.peft import get_adapter_state_dict
 from torchtune.utils import get_device, get_logger
 from torchtune.utils._logging import deprecated
@@ -523,6 +522,7 @@ def shard_model(
     *,
     cpu_offload: bool,
     reshard_after_forward: bool = True,
+    dp_mesh: Optional[DeviceMesh] = None,
 ) -> None:
     """
     Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
@@ -541,11 +541,13 @@ def shard_model(
         reshard_after_forward (bool): Whether to reshard parameters and buffers after
             the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
             from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
+        dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under mutliple parallelism.
+            Default to None.
 
     Raises:
         ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
-    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
+    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
 
@@ -599,11 +601,11 @@ def prepare_mha_for_tp(
         >>> # num_kv_heads = 16 (32/2)
         >>> # embed_dim = 2048 (4096/2)
     """
-    # Consider the case of Deep Fusion models
-    if isinstance(model, DeepFusionModel):
-        model = model.decoder
+    # Handle fusion models by extracting decoder
+    is_fusion_model = isinstance(model, (DeepFusionModel, EarlyFusionModel))
+    decoder = model.decoder if is_fusion_model else model
     tp_size = tp_mesh.size()
-    for m in list(model.modules()):
+    for m in list(decoder.modules()):
         if isinstance(m, MultiHeadAttention):
             # Adjust attention module to use the local number of heads
             if m.num_heads % tp_size != 0:
@@ -624,4 +626,7 @@ def prepare_mha_for_tp(
             m.num_heads = m.num_heads // tp_size
             m.num_kv_heads = m.num_kv_heads // tp_size
             m.embed_dim = m.embed_dim // tp_size
+
+    if is_fusion_model:
+        model.decoder = decoder
     return model
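
Two notes on this file: shard_model forwards dp_mesh to fully_shard as its mesh, so the default of None should preserve the previous behavior of sharding over the default process group, and prepare_mha_for_tp now walks model.decoder for both DeepFusionModel and EarlyFusionModel, reattaching the adjusted decoder before returning. A small usage sketch, assuming a tensor-parallel group of size 2 launched via torchrun and an already-built (possibly fusion) model:

from torch.distributed import init_device_mesh
from torchtune import training


def halve_attention_for_tp(model, tp_size: int = 2):
    # 1-D tensor-parallel mesh; assumes init_process_group has already run.
    tp_mesh = init_device_mesh("cuda", (tp_size,))
    # Each MultiHeadAttention under the decoder now reports local, per-shard values:
    # num_heads, num_kv_heads and embed_dim are divided by tp_size. For fusion
    # models the adjusted decoder is written back onto model.decoder.
    return training.prepare_mha_for_tp(model, tp_mesh)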
