[Fix] Muon optimizer per-expert orthogonalization for MoE models #1582
base: main
Changes from all commits
2ce8f52
cd0fe5b
3586707
72e2a42
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # AGENTS.md - Project Guidelines for AI Assistants | ||
|
|
||
| This file contains project-specific guidelines and conventions for AI assistants working on this codebase. | ||
|
|
||
| ## Git Commit Message Guidelines | ||
|
|
||
| ### Format | ||
|
|
||
| ``` | ||
| [<Type>] <Short summary> | ||
|
|
||
| <Long description explaining what and why (not how)> | ||
|
|
||
| - Bullet point for specific changes | ||
| - Another bullet point | ||
| ``` | ||
|
|
||
| ### Types | ||
|
|
||
| - `[Feature]` - New feature | ||
| - `[Fix]` - Bug fix | ||
| - `[Refactor]` - Code refactoring | ||
| - `[Docs]` - Documentation changes | ||
| - `[Test]` - Test changes | ||
| - `[Chore]` - Build/tooling changes | ||
|
|
||
| ### Guidelines | ||
|
|
||
| 1. **Short summary**: Concise description of the change (50 chars or less) | ||
| 2. **Long description**: Explain **what** changed and **why**, not **how** | ||
| 3. **No bullet points**: Do not list specific changes in commit message | ||
| 4. **No file lists**: Do not include file names or "Files modified:" section | ||
| 5. **Keep it brief**: Only high-level functional description, details go to PR description | ||
|
|
||
| ### Example | ||
|
|
||
| ``` | ||
| [Fix] Muon optimizer per-expert orthogonalization for MoE models | ||
|
|
||
| Fix Muon optimizer to apply orthogonalization per expert matrix instead of | ||
| to the concatenated large matrix for MoE models. | ||
| ``` | ||
|
|
||
| ## PR Description Guidelines | ||
|
|
||
| The PR description should contain: | ||
|
|
||
| 1. **Summary**: Brief overview of the changes | ||
| 2. **Motivation**: Why this change is needed | ||
| 3. **Changes**: Detailed list of what changed | ||
| 4. **Files modified**: List of files changed | ||
| 5. **Testing**: How the changes were tested | ||
|
|
||
| ## Code Style | ||
|
|
||
| - Follow existing code style in the project | ||
| - Add type hints for new functions | ||
| - Add docstrings for public functions and classes | ||
| - Keep functions focused and small | ||
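The commit message rules above are mechanical enough to check in a script. The following is a minimal sketch and is not part of this PR: `check_commit_message`, `ALLOWED_TYPES`, and `MAX_SUMMARY_CHARS` are hypothetical names, and it assumes the 50-character limit applies to the summary text after the `[<Type>]` tag, which the guidelines do not specify.

```python
import re

# Types listed under "Types" in the guidelines.
ALLOWED_TYPES = {"Feature", "Fix", "Refactor", "Docs", "Test", "Chore"}
# Assumption: the limit covers only the text after the "[<Type>] " prefix.
MAX_SUMMARY_CHARS = 50


def check_commit_message(message: str) -> list[str]:
    """Return a list of guideline violations; an empty list means the message passes."""
    problems = []
    lines = message.strip().splitlines()
    header = lines[0] if lines else ""

    match = re.fullmatch(r"\[(\w+)\] (.+)", header)
    if match is None:
        problems.append("header must look like '[<Type>] <Short summary>'")
    else:
        commit_type, summary = match.groups()
        if commit_type not in ALLOWED_TYPES:
            problems.append(f"unknown type: {commit_type}")
        if len(summary) > MAX_SUMMARY_CHARS:
            problems.append(f"summary is {len(summary)} chars, limit is {MAX_SUMMARY_CHARS}")

    body = lines[1:]
    # Guideline 3: no bullet points in the commit body.
    if any(line.lstrip().startswith(("- ", "* ")) for line in body):
        problems.append("body should not contain bullet points")
    # Guideline 4: no "Files modified:" section.
    if any(line.strip().lower().startswith("files modified") for line in body):
        problems.append("body should not list files")
    return problems


example = """[Fix] Muon optimizer per-expert orthogonalization for MoE models

Fix Muon optimizer to apply orthogonalization per expert matrix instead of
to the concatenated large matrix for MoE models.
"""
print(check_commit_message(example))
```

Run against the example commit message from the guidelines, this sketch reports the summary as longer than 50 characters (the header exceeds the limit under either reading of the rule). It would also flag the bullet points shown in the Format template, which guideline 3 forbids.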
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,247 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| """Test Muon optimizer Newton-Schulz functions with/without torch.compile. | ||
|
|
||
| Test shapes are based on Qwen3-30B-A3B model config: | ||
| - hidden_size: 2048 | ||
| - num_experts: 128 | ||
| - moe_intermediate_size: 768 | ||
| - intermediate_size: 6144 (for shared expert) | ||
|
|
||
| MoE expert weight shapes: | ||
| - w1/w3: (num_experts * moe_intermediate_size, hidden_size) = (98304, 2048) | ||
| per expert: (768, 2048) | ||
| - w2: (hidden_size, num_experts * moe_intermediate_size) = (2048, 98304) | ||
| per expert: (2048, 768) | ||
|
|
||
| For testing, we use scaled-down versions to keep tests fast while maintaining | ||
| representative shapes. | ||
|
|
||
| ================================================================================ | ||
| IMPORTANT: DTensor Compatibility Note | ||
| ================================================================================ | ||
|
|
||
| The zeropower_via_newtonschulz5 function supports DTensor input, but with a | ||
| known limitation when M > N (e.g., w2 weights where hidden_size > moe_intermediate_size). | ||
|
|
||
| Root Cause Analysis (verified by /tmp/test_dtensor_root_cause_detailed.py): | ||
| --------------------------------------------------------------------------- | ||
| When M > N, the Newton-Schulz algorithm transposes the input matrix: | ||
| X = G.view(1, M, N).mT # becomes (1, N, M) | ||
|
|
||
| For a DTensor sharded on dim 0 (M dimension): | ||
| 1. After view(1, M, N): placements become Shard(dim=1) | ||
| 2. After mT: placements become Shard(dim=2) # the M dimension moves to dim 2 | ||
| 3. X @ X.mT produces Partial(sum) DTensor # contraction dim is sharded | ||
| 4. Partial values are not correctly reduced in subsequent operations | ||
| 5. Error accumulates across 5 Newton-Schulz iterations: | ||
| Iter 1: X max ~0.016 | ||
| Iter 2: X max ~0.060 | ||
| Iter 3: X max ~0.099 | ||
| Iter 4: X max ~0.29 | ||
| Iter 5: X max ~47.5 (EXPLOSION!) | ||
| 6. Final result is completely wrong (e.g., 0.1 -> 47.5) | ||
|
|
||
| Verification Results: | ||
| - M < N (w1/w3): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly | ||
| - M > N (w2): ✗ FAIL - A @ A.mT produces Partial(sum), results explode | ||
| - M = N (square): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly | ||
|
|
||
| Workaround: | ||
| For DTensor with M > N (w2 weights), convert to local tensor: | ||
| result = zeropower_via_newtonschulz5(dtensor.to_local(), num_experts=1) | ||
|
|
||
| Note: | ||
| This is NOT a torch.compile issue. The same problem occurs with or without | ||
| torch.compile. It's a fundamental limitation of DTensor's Partial placement | ||
| handling in complex matrix operation chains. | ||
|
|
||
| newton_schulz_triton: | ||
| Does not support DTensor at all due to direct Triton kernel usage. | ||
| Must use .to_local() to convert before calling. | ||
|
Comment on lines +1 to +60
Claude: Warning: The module docstring contains ~40 lines of debug analysis (DTensor root cause, iteration-by-iteration error growth, references to […]). Consider trimming this to a brief summary of what's tested and why, with a link to an issue or doc if the DTensor limitation needs to be tracked. The reference to […]
|
||
| ================================================================================ | ||
| """ | ||
|
|
||
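Steps 3 and 4 of the root-cause list in the docstring can be illustrated without DTensor. The sketch below uses plain Python lists and hypothetical helpers (`matmul`, `transpose`, `add`) to show the arithmetic only: when the contraction dimension of `X @ X.mT` is split across ranks, each rank holds a partial sum, and a later nonlinear step such as `A @ A` gives a different answer if it is applied before the partials are reduced. It makes no claim about how DTensor handles `Partial` placements internally.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


# X has shape (N, M) = (2, 4), i.e. the matrix after the .mT in the M > N case.
# The M dimension (the contraction dimension of X @ X.mT) is split over two "ranks".
X = [[1, 2, 3, 4],
     [0, 1, 0, 2]]
shards = [[row[:2] for row in X], [row[2:] for row in X]]

# Each rank's local Gram matrix is only a partial sum of the full result.
partials = [matmul(s, transpose(s)) for s in shards]
full_gram = matmul(X, transpose(X))

# Summing (reducing) the partials recovers the full Gram matrix.
reduced = add(partials[0], partials[1])
assert reduced == full_gram

# Squaring before the reduction is not the same as squaring after it.
squared_then_summed = add(matmul(partials[0], partials[0]),
                          matmul(partials[1], partials[1]))
summed_then_squared = matmul(full_gram, full_gram)
print(squared_then_summed)
print(summed_then_squared)
```

The two printed matrices differ, which is the kind of discrepancy that would compound over repeated Newton-Schulz iterations if a reduction were skipped.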
| import pytest | ||
| import torch | ||
|
|
||
| # Skip all tests if CUDA is not available | ||
| pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
|
|
||
|
|
||
| class TestNewtonSchulzCompile: | ||
| """Test Newton-Schulz functions with and without torch.compile.""" | ||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def setup(self): | ||
| """Setup test fixtures.""" | ||
| self.device = "cuda" | ||
| self.dtype = torch.bfloat16 | ||
| self.epsilon = 1e-7 | ||
| self.tolerance = 1e-3 # Tolerance for bfloat16 comparison | ||
|
|
||
| def _create_test_matrix(self, num_experts, M, N): | ||
| """Create a test matrix with given dimensions.""" | ||
| shape = (num_experts * M, N) | ||
| return torch.randn(shape, device=self.device, dtype=torch.float32) | ||
|
|
||
| def test_zeropower_via_newtonschulz5_compile(self): | ||
| """Test muon.zeropower_via_newtonschulz5 with/without compile. | ||
|
|
||
| Test cases based on Qwen3 MoE architecture (hidden_size=2048, num_experts=128): | ||
| - Non-MoE: (6144, 2048) and (2048, 6144) for shared experts | ||
| - MoE w1/w3: (128 * 768, 2048) per expert (768, 2048) | ||
| - MoE w2: (2048, 128 * 768) per expert (2048, 768) | ||
| """ | ||
| from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 | ||
|
|
||
| # Scaled-down test cases based on Qwen3 MoE config | ||
| test_cases = [ | ||
| # Non-MoE cases (shared expert-like) | ||
| (1, 1536, 512, "shared_expert_w1"), # (1536, 512) scaled from (6144, 2048) | ||
| (1, 512, 1536, "shared_expert_w2"), # (512, 1536) scaled from (2048, 6144) | ||
| # MoE cases - w1/w3 like (M < N) | ||
| (8, 192, 512, "moe_w1_small"), # per expert: (192, 512) scaled from (768, 2048) | ||
| (16, 192, 512, "moe_w1_medium"), # 16 experts | ||
| # MoE cases - w2 like (M > N) | ||
| (8, 512, 192, "moe_w2_small"), # per expert: (512, 192) scaled from (2048, 768) | ||
| (16, 512, 192, "moe_w2_medium"), # 16 experts | ||
| # Square cases | ||
| (1, 512, 512, "square_regular"), | ||
| (4, 256, 256, "square_moe"), | ||
| ] | ||
|
|
||
| for num_experts, M, N, name in test_cases: | ||
| G = self._create_test_matrix(num_experts, M, N) | ||
|
|
||
| # Without compile | ||
| result_no_compile = zeropower_via_newtonschulz5( | ||
| G, epsilon=self.epsilon, num_experts=num_experts | ||
| ) | ||
|
|
||
| # With compile | ||
| compiled_fn = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) | ||
| result_compile = compiled_fn(G, epsilon=self.epsilon, num_experts=num_experts) | ||
|
|
||
| # Compare results | ||
| max_diff = (result_no_compile - result_compile).abs().max().item() | ||
| assert max_diff < self.tolerance, ( | ||
| f"{name} (num_experts={num_experts}, M={M}, N={N}): " | ||
| f"max_diff={max_diff} >= {self.tolerance}" | ||
| ) | ||
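The reason per-expert orthogonalization matters can be seen on a toy input. The following is a dependency-free sketch, not the PR's `zeropower_via_newtonschulz5` (which is not shown in this diff): `newton_schulz` and `newton_schulz_per_expert` are hypothetical names, the matrices are plain lists rather than tensors, and the quintic coefficients are the ones used in widely circulated Muon implementations, which is an assumption here.

```python
import math

# Assumption: coefficients from public Muon implementations, not taken from this PR.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately orthogonalize one matrix given as a list of rows."""
    transposed = len(g) > len(g[0])  # work on the wide orientation when rows > cols
    x = transpose(g) if transposed else [row[:] for row in g]
    norm = math.sqrt(sum(v * v for row in x for v in row))  # Frobenius norm
    x = [[v / (norm + eps) for v in row] for row in x]
    for _ in range(steps):
        a = matmul(x, transpose(x))
        a2 = matmul(a, a)
        b = [[B_COEF * p + C_COEF * q for p, q in zip(ra, rq)] for ra, rq in zip(a, a2)]
        bx = matmul(b, x)
        x = [[A_COEF * p + q for p, q in zip(rx, rb)] for rx, rb in zip(x, bx)]
    return transpose(x) if transposed else x


def newton_schulz_per_expert(g, num_experts):
    """Split a stacked (num_experts * M, N) matrix by expert and orthogonalize each block."""
    rows_per_expert = len(g) // num_experts
    out = []
    for e in range(num_experts):
        out.extend(newton_schulz(g[e * rows_per_expert:(e + 1) * rows_per_expert]))
    return out


# Two 2x2 experts stacked along dim 0, with very different gradient scales.
stacked = [[2.0, 0.0], [0.0, 2.0], [100.0, 0.0], [0.0, 100.0]]
whole = newton_schulz(stacked)  # treats the stack as a single (4, 2) matrix
per_expert = newton_schulz_per_expert(stacked, num_experts=2)

print("whole, expert 0 entry:     ", whole[0][0])
print("per-expert, expert 0 entry:", per_expert[0][0])
```

When the stack is treated as one matrix, the small-scale expert's block stays small relative to the large-scale expert's block after orthogonalization. Processing each expert separately gives both blocks entries of comparable magnitude.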
|
|
||
| def test_newton_schulz_triton(self): | ||
| """Test newton_schulz_triton (Triton kernel, no torch.compile). | ||
|
|
||
| Note: Triton kernel is not compatible with torch.compile, so we only test | ||
| without compile and verify basic correctness. | ||
| """ | ||
| from xtuner.v1.optim.newton_schulz_triton import newton_schulz_triton | ||
|
|
||
| # Scaled-down test cases based on Qwen3 MoE config | ||
| test_cases = [ | ||
| # Non-MoE cases (shared expert-like) | ||
| (1, 1536, 512, "shared_expert_w1"), # (1536, 512) | ||
| (1, 512, 1536, "shared_expert_w2"), # (512, 1536) | ||
| # MoE cases - w1/w3 like (M < N) | ||
| (8, 192, 512, "moe_w1_small"), # 8 experts, each (192, 512) | ||
| (16, 192, 512, "moe_w1_medium"), # 16 experts | ||
| # MoE cases - w2 like (M > N) | ||
| (8, 512, 192, "moe_w2_small"), # 8 experts, each (512, 192) | ||
| (16, 512, 192, "moe_w2_medium"), # 16 experts | ||
| # Square cases | ||
| (1, 512, 512, "square_regular"), | ||
| (4, 256, 256, "square_moe"), | ||
| ] | ||
|
|
||
| for num_experts, M, N, name in test_cases: | ||
| G = self._create_test_matrix(num_experts, M, N) | ||
|
|
||
| # Test without compile (Triton kernel doesn't support compile) | ||
| result = newton_schulz_triton(G, epsilon=self.epsilon, num_experts=num_experts) | ||
|
|
||
| # Basic sanity check: output should have correct shape | ||
| assert result.shape == G.shape, f"{name}: output shape mismatch" | ||
|
|
||
| # Output should not be all zeros or contain NaN/Inf | ||
| assert not torch.isnan(result).any(), f"{name}: output contains NaN" | ||
| assert not torch.isinf(result).any(), f"{name}: output contains Inf" | ||
| assert result.abs().max() > 0, f"{name}: output is all zeros" | ||
|
|
||
| def test_transpose_case_compile(self): | ||
| """Test matrices where rows > cols (transpose case) with compile. | ||
|
|
||
| Based on Qwen3 MoE w2 shape: (hidden_size, num_experts * moe_intermediate_size) | ||
| """ | ||
| from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 | ||
|
|
||
| test_cases = [ | ||
| # Non-MoE transpose case | ||
| (1, 512, 128, "transpose_shared_expert"), # Scaled from (2048, 512) | ||
| # MoE transpose cases - w2 like | ||
| (8, 512, 192, "transpose_moe_w2_small"), # 8 experts, each (512, 192) | ||
| (16, 512, 192, "transpose_moe_w2_medium"), # 16 experts | ||
| ] | ||
|
|
||
| for num_experts, M, N, name in test_cases: | ||
| G = self._create_test_matrix(num_experts, M, N) | ||
|
|
||
| # Without compile | ||
| result_no_compile = zeropower_via_newtonschulz5( | ||
| G, epsilon=self.epsilon, num_experts=num_experts | ||
| ) | ||
|
|
||
| # With compile | ||
| compiled_fn = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) | ||
| result_compile = compiled_fn(G, epsilon=self.epsilon, num_experts=num_experts) | ||
|
|
||
| # Compare results | ||
| max_diff = (result_no_compile - result_compile).abs().max().item() | ||
| assert max_diff < self.tolerance, ( | ||
| f"zeropower_via_newtonschulz5 {name} (num_experts={num_experts}): " | ||
| f"max_diff={max_diff} >= {self.tolerance}" | ||
| ) | ||
|
|
||
| def test_two_functions_consistency(self): | ||
| """Test that both functions produce similar results. | ||
|
|
||
| Compare Triton implementation with PyTorch reference implementation | ||
| using shapes from Qwen3 MoE architecture. | ||
| """ | ||
| from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 | ||
| from xtuner.v1.optim.newton_schulz_triton import newton_schulz_triton | ||
|
|
||
| # Scaled-down test cases based on Qwen3 MoE config | ||
| test_cases = [ | ||
| # Non-MoE cases | ||
| (1, 1536, 512, "shared_expert_w1"), | ||
| (1, 512, 1536, "shared_expert_w2"), | ||
| # MoE w1/w3 like (M < N) | ||
| (8, 192, 512, "moe_w1"), | ||
| # MoE w2 like (M > N) | ||
| (8, 512, 192, "moe_w2"), | ||
| # Square cases | ||
| (1, 512, 512, "square_regular"), | ||
| (4, 256, 256, "square_moe"), | ||
| ] | ||
|
|
||
| for num_experts, M, N, name in test_cases: | ||
| G = self._create_test_matrix(num_experts, M, N) | ||
|
|
||
| result1 = zeropower_via_newtonschulz5( | ||
| G, epsilon=self.epsilon, num_experts=num_experts | ||
| ) | ||
| result2 = newton_schulz_triton( | ||
| G, epsilon=self.epsilon, num_experts=num_experts | ||
| ) | ||
|
|
||
| max_diff = (result1 - result2).abs().max().item() | ||
|
Claude: Nit: The tolerance was relaxed from […]
|
||
| # Allow larger tolerance since implementations differ (PyTorch vs Triton) | ||
| # Triton uses different kernel implementations which may have numerical differences | ||
| assert max_diff < 3e-2, ( | ||
| f"Functions differ for {name} (num_experts={num_experts}, M={M}, N={N}): " | ||
| f"max_diff={max_diff}" | ||
| ) | ||
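One plausible contributor to differences above `1e-3` between two implementations is bfloat16 rounding. The fixture's `self.dtype = torch.bfloat16` and its tolerance comment hint at bfloat16 arithmetic, but the kernels are not shown in this diff, so this is an assumption. The sketch below emulates bfloat16 rounding with the standard library only; `to_bfloat16` is a hypothetical helper that handles ordinary finite values and does not special-case NaN.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bias = 0x7FFF + ((bits >> 16) & 1)                   # ties go to the even mantissa
    rounded = ((bits + bias) >> 16) << 16                # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", rounded & 0xFFFFFFFF))[0]


value = 1.01
error = abs(to_bfloat16(value) - value)
print(f"bfloat16({value}) = {to_bfloat16(value)}, abs error = {error:.6f}")
print("exceeds 1e-3:", error > 1e-3)
```

Near 1.0, adjacent bfloat16 values are 2^-7 apart, so a single rounding can already move a value by more than `1e-3`. Two kernels that round at different points in the iteration can therefore disagree by more than that without either being wrong.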
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) | ||
Claude: Warning: Per project guidelines ("One logical change per PR. Do not mix bug fixes with features or refactors"), this documentation file should be in a separate PR. Also, this repo already has `.claude/CLAUDE.md` with project conventions; having a second set of guidelines in `AGENTS.md` at the repo root creates a risk of divergence. Consider whether this is needed at all, or if it should extend the existing CLAUDE.md.