Add support for loss parallel #1546

Merged · 3 commits · Aug 10, 2025
32 changes: 31 additions & 1 deletion torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -88,18 +88,48 @@ def input_fn():
"dp_shard": Shard(0),
"tp": Replicate(),
}
# only used if loss parallel is enabled
possible_output_shardings = {
# maps relative to mesh dim names used in torchtitan
"dp_shard": Shard(0),
"tp": Shard(2),
}
assert all(
name in possible_input_shardings for name in world_mesh.mesh_dim_names
), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
x_sharding = tuple(
possible_input_shardings[name] for name in world_mesh.mesh_dim_names
)
out_sharding = x_sharding
if parallel_dims.loss_parallel_enabled:
out_sharding = tuple(
possible_output_shardings[name]
for name in world_mesh.mesh_dim_names
if name != "dp_replicate"
)
autop.add_input_constraints([x_sharding])
autop.add_output_constraints([x_sharding])
autop.add_output_constraints([out_sharding])
t0 = time.time()
sharding_placement = autop.optimize_placement()
t1 = time.time()
logger.info(f"AutoParallel took {t1 - t0} seconds")
parallel_mod = autop.apply_placement(sharding_placement)

+    if parallel_dims.loss_parallel_enabled:
+
+        # PyTorch's current implementation of loss parallel assumes
+        # that the DTensor has a 1d device mesh. This is not true
+        # in our case, but we can work around it by casting the
+        # output to a DTensor on a 1d device mesh.
+        # We should just use AutoParallel to do this for us, but
+        # it would require putting the loss inside the model as well
Reviewer comment (Contributor): I agree that overall we should just put the loss in the model, but I like the approach here for now because it's useful to be as structurally similar to torchtitan as possible for drop-in purposes.

+        def _return_as_dtensor_for_loss_parallel(module, args, output):
+            return torch.distributed.tensor.DTensor.from_local(
+                output, world_mesh["tp"], (Shard(2),)
+            )
+
+        # not keeping a reference to the hook, don't plan on
+        # removing it at any point
+        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
+
     return parallel_mod
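
For readers following along, here is a minimal sketch of how the vocab-sharded DTensor produced by the forward hook above would typically be consumed. The training_step function, the inputs/labels tensors, and the ("dp_shard", "tp") mesh mentioned in the comments are illustrative assumptions and are not part of this PR; only the placements and the hook behavior come from the diff above.

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

# With a ("dp_shard", "tp") world mesh, the constraints above resolve to
#   x_sharding   = (Shard(0), Replicate())  # inputs sharded over batch, replicated over tp
#   out_sharding = (Shard(0), Shard(2))     # logits sharded over batch and vocab
# and the forward hook re-wraps the local output as a DTensor on the 1d
# "tp" sub-mesh with placement Shard(2), which is what loss_parallel expects.

def training_step(parallel_mod, inputs, labels):
    # logits comes back as a DTensor sharded on the vocab dimension
    # because of the forward hook registered in the diff above
    logits = parallel_mod(inputs)
    # loss_parallel computes the cross-entropy without all-gathering the
    # vocab-sharded logits across the tp group; backward must also run
    # inside the same context
    with loss_parallel():
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()
    return loss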