diff --git a/xtuner/v1/optim/muon.py b/xtuner/v1/optim/muon.py index 5a8a8200c..9b0bc7666 100644 --- a/xtuner/v1/optim/muon.py +++ b/xtuner/v1/optim/muon.py @@ -584,69 +584,150 @@ def muon_update_batch_async( assert process_group is not None, "process_group must be provided for sharded DTensors" assert isinstance(X[0], DTensor), "X should contain DTensors" assert not isinstance(U[0], DTensor), "U should contain local shards" - assert X[0].size(shard_dim) % world_size == 0, ( - f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." - ) - # Pack the list of shards into a single contiguous tensor - # U is currently [Shard_for_Rank0, Shard_for_Rank1, ...] - # Stack creates shape: (World_Size, *Shard_Shape) - U_packed = torch.stack(U) - - # Allocate buffer to receive parts of the "Single Matrix" - # Shape: (World_Size, *Shard_Shape) - single_matrix_parts = torch.empty_like(U_packed) - - # Perform optimized All-to-All - # This sends one large contiguous buffer instead of many small ones - work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True) - yield - work.wait() - - # Reconstruct the full matrix - # single_matrix_parts has shape (World_Size, D0, D1...) - if shard_dim == 0: - # Optimization: If sharded on dim 0, we can simply flatten the batch dim - # to reconstruct the full matrix. This is a Zero-Copy View. - single_matrix = single_matrix_parts.flatten(0, 1) - else: - # General case (e.g., Col-wise sharding): We must concatenate along shard_dim. - # This requires a memory copy. - single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim) + global_shard_dim_size = X[0].size(shard_dim) - # 5. Perform Newton-Schulz Orthogonalization - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) + if global_shard_dim_size >= world_size and global_shard_dim_size % world_size == 0: + # Standard path: use all-to-all for evenly sharded tensors + # Pack the list of shards into a single contiguous tensor + # U is currently [Shard_for_Rank0, Shard_for_Rank1, ...] + # Stack creates shape: (World_Size, *Shard_Shape) + U_packed = torch.stack(U) - # Prepare to scatter results back - if shard_dim == 0: - # Optimization: View back to (World_Size, Shard_Size, ...) - # This is a Zero-Copy View. - single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:]) - else: - # General case: Split back into chunks and stack them. - # We use stack to ensure the output is contiguous (World_Size, ...) for NCCL - single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim)) + # Allocate buffer to receive parts of the "Single Matrix" + # Shape: (World_Size, *Shard_Shape) + single_matrix_parts = torch.empty_like(U_packed) + + # Perform optimized All-to-All + # This sends one large contiguous buffer instead of many small ones + work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True) + yield + work.wait() - # Ensure contiguity is preserved (crucial for NCCL) - if not single_matrix_shards_packed.is_contiguous(): - single_matrix_shards_packed = single_matrix_shards_packed.contiguous() + # Reconstruct the full matrix + # single_matrix_parts has shape (World_Size, D0, D1...) + if shard_dim == 0: + # Optimization: If sharded on dim 0, we can simply flatten the batch dim + # to reconstruct the full matrix. This is a Zero-Copy View. + single_matrix = single_matrix_parts.flatten(0, 1) + else: + # General case (e.g., Col-wise sharding): We must concatenate along shard_dim. + # This requires a memory copy. + single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim) + + # 5. Perform Newton-Schulz Orthogonalization + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + # Prepare to scatter results back + if shard_dim == 0: + # Optimization: View back to (World_Size, Shard_Size, ...) + # This is a Zero-Copy View. + single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:]) + else: + # General case: Split back into chunks and stack them. + # We use stack to ensure the output is contiguous (World_Size, ...) for NCCL + single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim)) - # Allocate buffer for receiving updated gradients - U_packed_back = torch.empty_like(single_matrix_shards_packed) + # Ensure contiguity is preserved (crucial for NCCL) + if not single_matrix_shards_packed.is_contiguous(): + single_matrix_shards_packed = single_matrix_shards_packed.contiguous() - # Perform optimized All-to-All (Scatter back) - work = dist.all_to_all_single(U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True) - yield - work.wait() + # Allocate buffer for receiving updated gradients + U_packed_back = torch.empty_like(single_matrix_shards_packed) - # Unpack back to list form for the post-processing function - # unbind(0) is a view operation (slicing) - U = list(U_packed_back.unbind(0)) + # Perform optimized All-to-All (Scatter back) + work = dist.all_to_all_single( + U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True + ) + yield + work.wait() + + # Unpack back to list form for the post-processing function + # unbind(0) is a view operation (slicing) + U = list(U_packed_back.unbind(0)) + + else: + # Small matrix path: when global_shard_dim_size < world_size or not evenly divisible, + # use all-gather + redundant orthogonalization on each rank. + # This handles uneven sharding where some ranks may have empty shards. + + # Calculate padded shard size (ceil division) so all ranks have same-sized tensors + padded_shard_size = (global_shard_dim_size + world_size - 1) // world_size + + # Pad all local shards to the same size for uniform all-gather + U_padded = [] + for u in U: + current_size = u.size(shard_dim) + if current_size < padded_shard_size: + pad_size = padded_shard_size - current_size + pad_shape = list(u.shape) + pad_shape[shard_dim] = pad_size + padding = torch.zeros(pad_shape, dtype=u.dtype, device=u.device) + u_padded = torch.cat([u, padding], dim=shard_dim) + else: + u_padded = u + U_padded.append(u_padded) + + # Stack into single tensor: (world_size, padded_shard_size, ...) + U_packed = torch.stack(U_padded) + + # All-gather to get all shards from all ranks + # Output shape: (world_size, world_size, padded_shard_size, ...) + gathered = torch.empty((world_size,) + U_packed.shape, dtype=U_packed.dtype, device=U_packed.device) + work = dist.all_gather_into_tensor(gathered, U_packed, group=process_group, async_op=True) + yield + work.wait() + + # Compute true local sizes for each rank using DTensor's sharding logic: + # DTensor divides into ceil-sized chunks, with the last chunk getting the remainder + local_sizes = [] + chunk_size = padded_shard_size # This is ceil(global_size / world_size) + for r in range(world_size): + start = r * chunk_size + end = min((r + 1) * chunk_size, global_shard_dim_size) + local_sizes.append(max(0, end - start)) + + # Reconstruct and orthogonalize all full matrices (redundant computation on each rank) + # gathered[r, i] = Rank r's padded local shard of matrix i + full_matrices = [] + for i in range(world_size): + shards = [] + for r in range(world_size): + shard = gathered[r, i] # (padded_shard_size, ...) + true_size = local_sizes[r] + if true_size > 0: + # Unpad: take only the true local size + shard = shard.narrow(shard_dim, 0, true_size) + shards.append(shard) + # Concatenate to get full matrix + full_matrix = torch.cat(shards, dim=shard_dim) + # Orthogonalize + full_matrix = muon_update_newton_schulz( + full_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + full_matrices.append(full_matrix) + + # Extract the shard for current rank from each orthogonalized matrix + offset = sum(local_sizes[:device_rank]) + my_size = local_sizes[device_rank] + U = [] + for fm in full_matrices: + if my_size > 0: + shard = fm.narrow(shard_dim, offset, my_size) + else: + # Empty shard for this rank + shape = list(fm.shape) + shape[shard_dim] = 0 + shard = torch.empty(shape, dtype=fm.dtype, device=fm.device) + U.append(shard) else: # Matrices are not sharded, so we can directly orthogonalize