Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 137 additions & 56 deletions xtuner/v1/optim/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,69 +584,150 @@ def muon_update_batch_async(
assert process_group is not None, "process_group must be provided for sharded DTensors"
assert isinstance(X[0], DTensor), "X should contain DTensors"
assert not isinstance(U[0], DTensor), "U should contain local shards"
assert X[0].size(shard_dim) % world_size == 0, (
f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}."
)

# Pack the list of shards into a single contiguous tensor
# U is currently [Shard_for_Rank0, Shard_for_Rank1, ...]
# Stack creates shape: (World_Size, *Shard_Shape)
U_packed = torch.stack(U)

# Allocate buffer to receive parts of the "Single Matrix"
# Shape: (World_Size, *Shard_Shape)
single_matrix_parts = torch.empty_like(U_packed)

# Perform optimized All-to-All
# This sends one large contiguous buffer instead of many small ones
work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True)
yield
work.wait()

# Reconstruct the full matrix
# single_matrix_parts has shape (World_Size, D0, D1...)
if shard_dim == 0:
# Optimization: If sharded on dim 0, we can simply flatten the batch dim
# to reconstruct the full matrix. This is a Zero-Copy View.
single_matrix = single_matrix_parts.flatten(0, 1)
else:
# General case (e.g., Col-wise sharding): We must concatenate along shard_dim.
# This requires a memory copy.
single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim)
global_shard_dim_size = X[0].size(shard_dim)

# 5. Perform Newton-Schulz Orthogonalization
single_matrix = muon_update_newton_schulz(
single_matrix,
newton_schulz_func=newton_schulz_func,
flatten=flatten,
epsilon=epsilon,
)
if global_shard_dim_size >= world_size and global_shard_dim_size % world_size == 0:
# Standard path: use all-to-all for evenly sharded tensors
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: The global_shard_dim_size >= world_size check is redundant for any positive global_shard_dim_size: if global_shard_dim_size % world_size == 0, then global_shard_dim_size >= world_size is guaranteed (the smallest positive value satisfying divisibility is world_size itself). The two conditions differ only when global_shard_dim_size is 0 (0 % world_size == 0, but 0 < world_size), so simplifying to just global_shard_dim_size % world_size == 0 would be cleaner as long as a zero-sized shard dimension cannot occur here.

That said, keeping it doesn't hurt correctness — it just makes the intent more explicit at the cost of a slightly longer condition.

# Pack the list of shards into a single contiguous tensor
# U is currently [Shard_for_Rank0, Shard_for_Rank1, ...]
# Stack creates shape: (World_Size, *Shard_Shape)
U_packed = torch.stack(U)

# Prepare to scatter results back
if shard_dim == 0:
# Optimization: View back to (World_Size, Shard_Size, ...)
# This is a Zero-Copy View.
single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:])
else:
# General case: Split back into chunks and stack them.
# We use stack to ensure the output is contiguous (World_Size, ...) for NCCL
single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim))
# Allocate buffer to receive parts of the "Single Matrix"
# Shape: (World_Size, *Shard_Shape)
single_matrix_parts = torch.empty_like(U_packed)

# Perform optimized All-to-All
# This sends one large contiguous buffer instead of many small ones
work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True)
yield
work.wait()

# Ensure contiguity is preserved (crucial for NCCL)
if not single_matrix_shards_packed.is_contiguous():
single_matrix_shards_packed = single_matrix_shards_packed.contiguous()
# Reconstruct the full matrix
# single_matrix_parts has shape (World_Size, D0, D1...)
if shard_dim == 0:
# Optimization: If sharded on dim 0, we can simply flatten the batch dim
# to reconstruct the full matrix. This is a Zero-Copy View.
single_matrix = single_matrix_parts.flatten(0, 1)
else:
# General case (e.g., Col-wise sharding): We must concatenate along shard_dim.
# This requires a memory copy.
single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim)

# 5. Perform Newton-Schulz Orthogonalization
single_matrix = muon_update_newton_schulz(
single_matrix,
newton_schulz_func=newton_schulz_func,
flatten=flatten,
epsilon=epsilon,
)

# Prepare to scatter results back
if shard_dim == 0:
# Optimization: View back to (World_Size, Shard_Size, ...)
# This is a Zero-Copy View.
single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:])
else:
# General case: Split back into chunks and stack them.
Comment on lines +627 to +632
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Note: The gathered tensor has shape (world_size, world_size, padded_shard_size, ...) — a W² factor in memory vs. the standard path's W factor. Since this only triggers for small or unevenly sharded matrices (shard-dimension size < world_size, or not divisible by world_size), this should be fine in practice. But it might be worth adding a brief comment explaining the memory trade-off, so future readers understand why this path isn't used universally.

# We use stack to ensure the output is contiguous (World_Size, ...) for NCCL
single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim))

# Allocate buffer for receiving updated gradients
U_packed_back = torch.empty_like(single_matrix_shards_packed)
# Ensure contiguity is preserved (crucial for NCCL)
if not single_matrix_shards_packed.is_contiguous():
single_matrix_shards_packed = single_matrix_shards_packed.contiguous()

# Perform optimized All-to-All (Scatter back)
work = dist.all_to_all_single(U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True)
yield
work.wait()
# Allocate buffer for receiving updated gradients
U_packed_back = torch.empty_like(single_matrix_shards_packed)

# Unpack back to list form for the post-processing function
# unbind(0) is a view operation (slicing)
U = list(U_packed_back.unbind(0))
# Perform optimized All-to-All (Scatter back)
work = dist.all_to_all_single(
U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True
)
yield
work.wait()

Comment on lines +637 to +649
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning: The local_sizes computation assumes DTensor uses ceil-division chunking (chunk_size = ceil(N/W); ranks are filled in order, the last non-empty rank gets the remainder, and any ranks after it get empty shards). This matches current torch.distributed._tensor behavior, but it's an implicit coupling — if DTensor's sharding strategy ever changes, this will silently produce incorrect results.

Consider either:

  1. Adding a comment citing the specific PyTorch DTensor sharding contract, or
  2. Computing local sizes from the actual U tensors (e.g., querying sizes from the DTensor placements) to make this resilient to future changes.

Not blocking, since the current assumption is correct today.

# Unpack back to list form for the post-processing function
# unbind(0) is a view operation (slicing)
U = list(U_packed_back.unbind(0))

else:
# Small matrix path: when global_shard_dim_size < world_size or not evenly divisible,
# use all-gather + redundant orthogonalization on each rank.
# This handles uneven sharding where some ranks may have empty shards.

# Calculate padded shard size (ceil division) so all ranks have same-sized tensors
padded_shard_size = (global_shard_dim_size + world_size - 1) // world_size

# Pad all local shards to the same size for uniform all-gather
U_padded = []
for u in U:
current_size = u.size(shard_dim)
if current_size < padded_shard_size:
pad_size = padded_shard_size - current_size
pad_shape = list(u.shape)
pad_shape[shard_dim] = pad_size
padding = torch.zeros(pad_shape, dtype=u.dtype, device=u.device)
u_padded = torch.cat([u, padding], dim=shard_dim)
else:
u_padded = u
U_padded.append(u_padded)

# Stack into single tensor: (world_size, padded_shard_size, ...)
U_packed = torch.stack(U_padded)

# All-gather to get all shards from all ranks
# Output shape: (world_size, world_size, padded_shard_size, ...)
gathered = torch.empty((world_size,) + U_packed.shape, dtype=U_packed.dtype, device=U_packed.device)
work = dist.all_gather_into_tensor(gathered, U_packed, group=process_group, async_op=True)
yield
work.wait()

# Compute true local sizes for each rank using DTensor's sharding logic:
# DTensor divides into ceil-sized chunks, with the last chunk getting the remainder
local_sizes = []
chunk_size = padded_shard_size # This is ceil(global_size / world_size)
for r in range(world_size):
start = r * chunk_size
end = min((r + 1) * chunk_size, global_shard_dim_size)
local_sizes.append(max(0, end - start))

# Reconstruct and orthogonalize all full matrices (redundant computation on each rank)
# gathered[r, i] = Rank r's padded local shard of matrix i
full_matrices = []
for i in range(world_size):
shards = []
for r in range(world_size):
shard = gathered[r, i] # (padded_shard_size, ...)
true_size = local_sizes[r]
if true_size > 0:
# Unpad: take only the true local size
shard = shard.narrow(shard_dim, 0, true_size)
shards.append(shard)
# Concatenate to get full matrix
full_matrix = torch.cat(shards, dim=shard_dim)
# Orthogonalize
full_matrix = muon_update_newton_schulz(
full_matrix,
newton_schulz_func=newton_schulz_func,
flatten=flatten,
epsilon=epsilon,
)
full_matrices.append(full_matrix)

# Extract the shard for current rank from each orthogonalized matrix
offset = sum(local_sizes[:device_rank])
my_size = local_sizes[device_rank]
U = []
for fm in full_matrices:
if my_size > 0:
shard = fm.narrow(shard_dim, offset, my_size)
else:
# Empty shard for this rank
shape = list(fm.shape)
shape[shard_dim] = 0
shard = torch.empty(shape, dtype=fm.dtype, device=fm.device)
U.append(shard)

else:
# Matrices are not sharded, so we can directly orthogonalize
Expand Down
Loading