-
Notifications
You must be signed in to change notification settings - Fork 410
[Fix][muon] fix assertion err when dim0 size < world_size #1611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -584,69 +584,150 @@ def muon_update_batch_async( | |
| assert process_group is not None, "process_group must be provided for sharded DTensors" | ||
| assert isinstance(X[0], DTensor), "X should contain DTensors" | ||
| assert not isinstance(U[0], DTensor), "U should contain local shards" | ||
| assert X[0].size(shard_dim) % world_size == 0, ( | ||
| f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." | ||
| ) | ||
|
|
||
| # Pack the list of shards into a single contiguous tensor | ||
| # U is currently [Shard_for_Rank0, Shard_for_Rank1, ...] | ||
| # Stack creates shape: (World_Size, *Shard_Shape) | ||
| U_packed = torch.stack(U) | ||
|
|
||
| # Allocate buffer to receive parts of the "Single Matrix" | ||
| # Shape: (World_Size, *Shard_Shape) | ||
| single_matrix_parts = torch.empty_like(U_packed) | ||
|
|
||
| # Perform optimized All-to-All | ||
| # This sends one large contiguous buffer instead of many small ones | ||
| work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True) | ||
| yield | ||
| work.wait() | ||
|
|
||
| # Reconstruct the full matrix | ||
| # single_matrix_parts has shape (World_Size, D0, D1...) | ||
| if shard_dim == 0: | ||
| # Optimization: If sharded on dim 0, we can simply flatten the batch dim | ||
| # to reconstruct the full matrix. This is a Zero-Copy View. | ||
| single_matrix = single_matrix_parts.flatten(0, 1) | ||
| else: | ||
| # General case (e.g., Col-wise sharding): We must concatenate along shard_dim. | ||
| # This requires a memory copy. | ||
| single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim) | ||
| global_shard_dim_size = X[0].size(shard_dim) | ||
|
|
||
| # 5. Perform Newton-Schulz Orthogonalization | ||
| single_matrix = muon_update_newton_schulz( | ||
| single_matrix, | ||
| newton_schulz_func=newton_schulz_func, | ||
| flatten=flatten, | ||
| epsilon=epsilon, | ||
| ) | ||
| if global_shard_dim_size >= world_size and global_shard_dim_size % world_size == 0: | ||
| # Standard path: use all-to-all for evenly sharded tensors | ||
| # Pack the list of shards into a single contiguous tensor | ||
| # U is currently [Shard_for_Rank0, Shard_for_Rank1, ...] | ||
| # Stack creates shape: (World_Size, *Shard_Shape) | ||
| U_packed = torch.stack(U) | ||
|
|
||
| # Prepare to scatter results back | ||
| if shard_dim == 0: | ||
| # Optimization: View back to (World_Size, Shard_Size, ...) | ||
| # This is a Zero-Copy View. | ||
| single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:]) | ||
| else: | ||
| # General case: Split back into chunks and stack them. | ||
| # We use stack to ensure the output is contiguous (World_Size, ...) for NCCL | ||
| single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim)) | ||
| # Allocate buffer to receive parts of the "Single Matrix" | ||
| # Shape: (World_Size, *Shard_Shape) | ||
| single_matrix_parts = torch.empty_like(U_packed) | ||
|
|
||
| # Perform optimized All-to-All | ||
| # This sends one large contiguous buffer instead of many small ones | ||
| work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True) | ||
| yield | ||
| work.wait() | ||
|
|
||
| # Ensure contiguity is preserved (crucial for NCCL) | ||
| if not single_matrix_shards_packed.is_contiguous(): | ||
| single_matrix_shards_packed = single_matrix_shards_packed.contiguous() | ||
| # Reconstruct the full matrix | ||
| # single_matrix_parts has shape (World_Size, D0, D1...) | ||
| if shard_dim == 0: | ||
| # Optimization: If sharded on dim 0, we can simply flatten the batch dim | ||
| # to reconstruct the full matrix. This is a Zero-Copy View. | ||
| single_matrix = single_matrix_parts.flatten(0, 1) | ||
| else: | ||
| # General case (e.g., Col-wise sharding): We must concatenate along shard_dim. | ||
| # This requires a memory copy. | ||
| single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim) | ||
|
|
||
| # 5. Perform Newton-Schulz Orthogonalization | ||
| single_matrix = muon_update_newton_schulz( | ||
| single_matrix, | ||
| newton_schulz_func=newton_schulz_func, | ||
| flatten=flatten, | ||
| epsilon=epsilon, | ||
| ) | ||
|
|
||
| # Prepare to scatter results back | ||
| if shard_dim == 0: | ||
| # Optimization: View back to (World_Size, Shard_Size, ...) | ||
| # This is a Zero-Copy View. | ||
| single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:]) | ||
| else: | ||
| # General case: Split back into chunks and stack them. | ||
|
Comment on lines +627 to +632
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Note: The |
||
| # We use stack to ensure the output is contiguous (World_Size, ...) for NCCL | ||
| single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim)) | ||
|
|
||
| # Allocate buffer for receiving updated gradients | ||
| U_packed_back = torch.empty_like(single_matrix_shards_packed) | ||
| # Ensure contiguity is preserved (crucial for NCCL) | ||
| if not single_matrix_shards_packed.is_contiguous(): | ||
| single_matrix_shards_packed = single_matrix_shards_packed.contiguous() | ||
|
|
||
| # Perform optimized All-to-All (Scatter back) | ||
| work = dist.all_to_all_single(U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True) | ||
| yield | ||
| work.wait() | ||
| # Allocate buffer for receiving updated gradients | ||
| U_packed_back = torch.empty_like(single_matrix_shards_packed) | ||
|
|
||
| # Unpack back to list form for the post-processing function | ||
| # unbind(0) is a view operation (slicing) | ||
| U = list(U_packed_back.unbind(0)) | ||
| # Perform optimized All-to-All (Scatter back) | ||
| work = dist.all_to_all_single( | ||
| U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True | ||
| ) | ||
| yield | ||
| work.wait() | ||
|
|
||
|
Comment on lines +637 to +649
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Warning: The Consider either:
Not blocking, since the current assumption is correct today. |
||
| # Unpack back to list form for the post-processing function | ||
| # unbind(0) is a view operation (slicing) | ||
| U = list(U_packed_back.unbind(0)) | ||
|
|
||
| else: | ||
| # Small matrix path: when global_shard_dim_size < world_size or not evenly divisible, | ||
| # use all-gather + redundant orthogonalization on each rank. | ||
| # This handles uneven sharding where some ranks may have empty shards. | ||
|
|
||
| # Calculate padded shard size (ceil division) so all ranks have same-sized tensors | ||
| padded_shard_size = (global_shard_dim_size + world_size - 1) // world_size | ||
|
|
||
| # Pad all local shards to the same size for uniform all-gather | ||
| U_padded = [] | ||
| for u in U: | ||
| current_size = u.size(shard_dim) | ||
| if current_size < padded_shard_size: | ||
| pad_size = padded_shard_size - current_size | ||
| pad_shape = list(u.shape) | ||
| pad_shape[shard_dim] = pad_size | ||
| padding = torch.zeros(pad_shape, dtype=u.dtype, device=u.device) | ||
| u_padded = torch.cat([u, padding], dim=shard_dim) | ||
| else: | ||
| u_padded = u | ||
| U_padded.append(u_padded) | ||
|
|
||
| # Stack into single tensor: (world_size, padded_shard_size, ...) | ||
| U_packed = torch.stack(U_padded) | ||
|
|
||
| # All-gather to get all shards from all ranks | ||
| # Output shape: (world_size, world_size, padded_shard_size, ...) | ||
| gathered = torch.empty((world_size,) + U_packed.shape, dtype=U_packed.dtype, device=U_packed.device) | ||
| work = dist.all_gather_into_tensor(gathered, U_packed, group=process_group, async_op=True) | ||
| yield | ||
| work.wait() | ||
|
|
||
| # Compute true local sizes for each rank using DTensor's sharding logic: | ||
| # DTensor divides into ceil-sized chunks, with the last chunk getting the remainder | ||
| local_sizes = [] | ||
| chunk_size = padded_shard_size # This is ceil(global_size / world_size) | ||
| for r in range(world_size): | ||
| start = r * chunk_size | ||
| end = min((r + 1) * chunk_size, global_shard_dim_size) | ||
| local_sizes.append(max(0, end - start)) | ||
|
|
||
| # Reconstruct and orthogonalize all full matrices (redundant computation on each rank) | ||
| # gathered[r, i] = Rank r's padded local shard of matrix i | ||
| full_matrices = [] | ||
| for i in range(world_size): | ||
| shards = [] | ||
| for r in range(world_size): | ||
| shard = gathered[r, i] # (padded_shard_size, ...) | ||
| true_size = local_sizes[r] | ||
| if true_size > 0: | ||
| # Unpad: take only the true local size | ||
| shard = shard.narrow(shard_dim, 0, true_size) | ||
| shards.append(shard) | ||
| # Concatenate to get full matrix | ||
| full_matrix = torch.cat(shards, dim=shard_dim) | ||
| # Orthogonalize | ||
| full_matrix = muon_update_newton_schulz( | ||
| full_matrix, | ||
| newton_schulz_func=newton_schulz_func, | ||
| flatten=flatten, | ||
| epsilon=epsilon, | ||
| ) | ||
| full_matrices.append(full_matrix) | ||
|
|
||
| # Extract the shard for current rank from each orthogonalized matrix | ||
| offset = sum(local_sizes[:device_rank]) | ||
| my_size = local_sizes[device_rank] | ||
| U = [] | ||
| for fm in full_matrices: | ||
| if my_size > 0: | ||
| shard = fm.narrow(shard_dim, offset, my_size) | ||
| else: | ||
| # Empty shard for this rank | ||
| shape = list(fm.shape) | ||
| shape[shard_dim] = 0 | ||
| shard = torch.empty(shape, dtype=fm.dtype, device=fm.device) | ||
| U.append(shard) | ||
|
|
||
| else: | ||
| # Matrices are not sharded, so we can directly orthogonalize | ||
|
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Claude: Nit: The `global_shard_dim_size >= world_size` check is redundant. For any positive `global_shard_dim_size`, if `global_shard_dim_size % world_size == 0`, then `global_shard_dim_size >= world_size` is guaranteed (the smallest positive value satisfying divisibility is `world_size` itself). Simplifying to just `global_shard_dim_size % world_size == 0` would be cleaner.
That said, keeping it doesn't hurt correctness — it just makes the intent more explicit at the cost of a slightly longer condition.