diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..5f15a4805 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,59 @@ +# AGENTS.md - Project Guidelines for AI Assistants + +This file contains project-specific guidelines and conventions for AI assistants working on this codebase. + +## Git Commit Message Guidelines + +### Format + +``` +[] + + + +- Bullet point for specific changes +- Another bullet point +``` + +### Types + +- `[Feature]` - New feature +- `[Fix]` - Bug fix +- `[Refactor]` - Code refactoring +- `[Docs]` - Documentation changes +- `[Test]` - Test changes +- `[Chore]` - Build/tooling changes + +### Guidelines + +1. **Short summary**: Concise description of the change (50 chars or less) +2. **Long description**: Explain **what** changed and **why**, not **how** +3. **No bullet points**: Do not list specific changes in commit message +4. **No file lists**: Do not include file names or "Files modified:" section +5. **Keep it brief**: Only high-level functional description, details go to PR description + +### Example + +``` +[Fix] Muon optimizer per-expert orthogonalization for MoE models + +Fix Muon optimizer to apply orthogonalization per expert matrix instead of +to the concatenated large matrix for MoE models. +``` + +## PR Description Guidelines + +The PR description should contain: + +1. **Summary**: Brief overview of the changes +2. **Motivation**: Why this change is needed +3. **Changes**: Detailed list of what changed +4. **Files modified**: List of files changed +5. **Testing**: How the changes were tested + +## Code Style + +- Follow existing code style in the project +- Add type hints for new functions +- Add docstrings for public functions and classes +- Keep functions focused and small diff --git a/tests/optim/test_muon_compile.py b/tests/optim/test_muon_compile.py new file mode 100644 index 000000000..8aa9eb84b --- /dev/null +++ b/tests/optim/test_muon_compile.py @@ -0,0 +1,247 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Test Muon optimizer Newton-Schulz functions with/without torch.compile. + +Test shapes are based on Qwen3-30B-A3B model config: +- hidden_size: 2048 +- num_experts: 128 +- moe_intermediate_size: 768 +- intermediate_size: 6144 (for shared expert) + +MoE expert weight shapes: +- w1/w3: (num_experts * moe_intermediate_size, hidden_size) = (98304, 2048) + per expert: (768, 2048) +- w2: (hidden_size, num_experts * moe_intermediate_size) = (2048, 98304) + per expert: (2048, 768) + +For testing, we use scaled-down versions to keep tests fast while maintaining +representative shapes. + +================================================================================ +IMPORTANT: DTensor Compatibility Note +================================================================================ + +The zeropower_via_newtonschulz5 function supports DTensor input, but with a +known limitation when M > N (e.g., w2 weights where hidden_size > moe_intermediate_size). + +Root Cause Analysis (verified by /tmp/test_dtensor_root_cause_detailed.py): +--------------------------------------------------------------------------- +When M > N, the Newton-Schulz algorithm transposes the input matrix: + X = G.view(1, M, N).mT # becomes (1, N, M) + +For a DTensor sharded on dim 0 (M dimension): + 1. After view(1, M, N): placements become Shard(dim=1) + 2. After mT: placements become Shard(dim=2) # the M dimension moves to dim 2 + 3. X @ X.mT produces Partial(sum) DTensor # contraction dim is sharded + 4. Partial values are not correctly reduced in subsequent operations + 5. Error accumulates across 5 Newton-Schulz iterations: + Iter 1: X max ~0.016 + Iter 2: X max ~0.060 + Iter 3: X max ~0.099 + Iter 4: X max ~0.29 + Iter 5: X max ~47.5 (EXPLOSION!) + 6. Final result is completely wrong (e.g., 0.1 -> 47.5) + +Verification Results: + - M < N (w1/w3): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly + - M > N (w2): ✗ FAIL - A @ A.mT produces Partial(sum), results explode + - M = N (square): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly + +Workaround: + For DTensor with M > N (w2 weights), convert to local tensor: + result = zeropower_via_newtonschulz5(dtensor.to_local(), num_experts=1) + +Note: + This is NOT a torch.compile issue. The same problem occurs with or without + torch.compile. It's a fundamental limitation of DTensor's Partial placement + handling in complex matrix operation chains. + +newton_schulz_triton: + Does not support DTensor at all due to direct Triton kernel usage. + Must use .to_local() to convert before calling. +================================================================================ +""" + +import pytest +import torch + +# Skip all tests if CUDA is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +class TestNewtonSchulzCompile: + """Test Newton-Schulz functions with and without torch.compile.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test fixtures.""" + self.device = "cuda" + self.dtype = torch.bfloat16 + self.epsilon = 1e-7 + self.tolerance = 1e-3 # Tolerance for bfloat16 comparison + + def _create_test_matrix(self, num_experts, M, N): + """Create a test matrix with given dimensions.""" + shape = (num_experts * M, N) + return torch.randn(shape, device=self.device, dtype=torch.float32) + + def test_zeropower_via_newtonschulz5_compile(self): + """Test muon.zeropower_via_newtonschulz5 with/without compile. + + Test cases based on Qwen3 MoE architecture (hidden_size=2048, num_experts=128): + - Non-MoE: (6144, 2048) and (2048, 6144) for shared experts + - MoE w1/w3: (128 * 768, 2048) per expert (768, 2048) + - MoE w2: (2048, 128 * 768) per expert (2048, 768) + """ + from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 + + # Scaled-down test cases based on Qwen3 MoE config + test_cases = [ + # Non-MoE cases (shared expert-like) + (1, 1536, 512, "shared_expert_w1"), # (1536, 512) scaled from (6144, 2048) + (1, 512, 1536, "shared_expert_w2"), # (512, 1536) scaled from (2048, 6144) + # MoE cases - w1/w3 like (M < N) + (8, 192, 512, "moe_w1_small"), # per expert: (192, 512) scaled from (768, 2048) + (16, 192, 512, "moe_w1_medium"), # 16 experts + # MoE cases - w2 like (M > N) + (8, 512, 192, "moe_w2_small"), # per expert: (512, 192) scaled from (2048, 768) + (16, 512, 192, "moe_w2_medium"), # 16 experts + # Square cases + (1, 512, 512, "square_regular"), + (4, 256, 256, "square_moe"), + ] + + for num_experts, M, N, name in test_cases: + G = self._create_test_matrix(num_experts, M, N) + + # Without compile + result_no_compile = zeropower_via_newtonschulz5( + G, epsilon=self.epsilon, num_experts=num_experts + ) + + # With compile + compiled_fn = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + result_compile = compiled_fn(G, epsilon=self.epsilon, num_experts=num_experts) + + # Compare results + max_diff = (result_no_compile - result_compile).abs().max().item() + assert max_diff < self.tolerance, ( + f"{name} (num_experts={num_experts}, M={M}, N={N}): " + f"max_diff={max_diff} >= {self.tolerance}" + ) + + def test_newton_schulz_triton(self): + """Test newton_schulz_triton (Triton kernel, no torch.compile). + + Note: Triton kernel is not compatible with torch.compile, so we only test + without compile and verify basic correctness. + """ + from xtuner.v1.optim.newton_schulz_triton import newton_schulz_triton + + # Scaled-down test cases based on Qwen3 MoE config + test_cases = [ + # Non-MoE cases (shared expert-like) + (1, 1536, 512, "shared_expert_w1"), # (1536, 512) + (1, 512, 1536, "shared_expert_w2"), # (512, 1536) + # MoE cases - w1/w3 like (M < N) + (8, 192, 512, "moe_w1_small"), # 8 experts, each (192, 512) + (16, 192, 512, "moe_w1_medium"), # 16 experts + # MoE cases - w2 like (M > N) + (8, 512, 192, "moe_w2_small"), # 8 experts, each (512, 192) + (16, 512, 192, "moe_w2_medium"), # 16 experts + # Square cases + (1, 512, 512, "square_regular"), + (4, 256, 256, "square_moe"), + ] + + for num_experts, M, N, name in test_cases: + G = self._create_test_matrix(num_experts, M, N) + + # Test without compile (Triton kernel doesn't support compile) + result = newton_schulz_triton(G, epsilon=self.epsilon, num_experts=num_experts) + + # Basic sanity check: output should have correct shape + assert result.shape == G.shape, f"{name}: output shape mismatch" + + # Output should not be all zeros or contain NaN/Inf + assert not torch.isnan(result).any(), f"{name}: output contains NaN" + assert not torch.isinf(result).any(), f"{name}: output contains Inf" + assert result.abs().max() > 0, f"{name}: output is all zeros" + + def test_transpose_case_compile(self): + """Test matrices where rows > cols (transpose case) with compile. + + Based on Qwen3 MoE w2 shape: (hidden_size, num_experts * moe_intermediate_size) + """ + from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 + + test_cases = [ + # Non-MoE transpose case + (1, 512, 128, "transpose_shared_expert"), # Scaled from (2048, 512) + # MoE transpose cases - w2 like + (8, 512, 192, "transpose_moe_w2_small"), # 8 experts, each (512, 192) + (16, 512, 192, "transpose_moe_w2_medium"), # 16 experts + ] + + for num_experts, M, N, name in test_cases: + G = self._create_test_matrix(num_experts, M, N) + + # Without compile + result_no_compile = zeropower_via_newtonschulz5( + G, epsilon=self.epsilon, num_experts=num_experts + ) + + # With compile + compiled_fn = torch.compile(zeropower_via_newtonschulz5, fullgraph=True) + result_compile = compiled_fn(G, epsilon=self.epsilon, num_experts=num_experts) + + # Compare results + max_diff = (result_no_compile - result_compile).abs().max().item() + assert max_diff < self.tolerance, ( + f"zeropower_via_newtonschulz5 {name} (num_experts={num_experts}): " + f"max_diff={max_diff} >= {self.tolerance}" + ) + + def test_two_functions_consistency(self): + """Test that both functions produce similar results. + + Compare Triton implementation with PyTorch reference implementation + using shapes from Qwen3 MoE architecture. + """ + from xtuner.v1.optim.muon import zeropower_via_newtonschulz5 + from xtuner.v1.optim.newton_schulz_triton import newton_schulz_triton + + # Scaled-down test cases based on Qwen3 MoE config + test_cases = [ + # Non-MoE cases + (1, 1536, 512, "shared_expert_w1"), + (1, 512, 1536, "shared_expert_w2"), + # MoE w1/w3 like (M < N) + (8, 192, 512, "moe_w1"), + # MoE w2 like (M > N) + (8, 512, 192, "moe_w2"), + # Square cases + (1, 512, 512, "square_regular"), + (4, 256, 256, "square_moe"), + ] + + for num_experts, M, N, name in test_cases: + G = self._create_test_matrix(num_experts, M, N) + + result1 = zeropower_via_newtonschulz5( + G, epsilon=self.epsilon, num_experts=num_experts + ) + result2 = newton_schulz_triton( + G, epsilon=self.epsilon, num_experts=num_experts + ) + + max_diff = (result1 - result2).abs().max().item() + # Allow larger tolerance since implementations differ (PyTorch vs Triton) + # Triton uses different kernel implementations which may have numerical differences + assert max_diff < 3e-2, ( + f"Functions differ for {name} (num_experts={num_experts}, M={M}, N={N}): " + f"max_diff={max_diff}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/xtuner/v1/config/optim.py b/xtuner/v1/config/optim.py index 5827edd8c..11c8840e8 100644 --- a/xtuner/v1/config/optim.py +++ b/xtuner/v1/config/optim.py @@ -49,7 +49,7 @@ def build(self, model): if dist.get_rank() == 0: logger.info( - f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" + f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, total parameters: {num_total / 1e6:.2f}M" ) logger.info(f"Untrainable parameters names: {untrainable_names}") return torch.optim.AdamW( @@ -71,8 +71,19 @@ def build(self, model): num_total = 0 num_total_requires_grad = 0 num_muon = 0 + num_muon_moe = 0 num_adamw = 0 + # Get MoE config if available + num_experts = getattr(model.config, "n_routed_experts", 1) or 1 + is_moe_model = num_experts > 1 + + # Expert parameter patterns for MoE models + # Note: fused_w1w3 contains both w1 and w3 weights, so num_experts = 2 * n_routed_experts + fused_w1w3_patterns = ("fused_w1w3",) + other_expert_patterns = ("fused_w2", "fused_w1", "fused_w3") + all_expert_patterns = fused_w1w3_patterns + other_expert_patterns + for name, p in model.named_parameters(): n = p.numel() num_total += n @@ -80,32 +91,69 @@ def build(self, model): num_total_requires_grad += n is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name if is_muon_tensor: - num_muon += n + # Check if this is an MoE expert parameter + if is_moe_model and any(pattern in name for pattern in all_expert_patterns): + num_muon_moe += n + else: + num_muon += n else: num_adamw += n else: untrainable_names.append(name) - muon_params = [ - p - for name, p in model.named_parameters() - if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ] + # Separate Muon params into regular and MoE expert params + # fused_w1w3 has 2 * num_experts (w1 and w3 each have num_experts) + # other expert params have num_experts + muon_params_regular = [] + muon_params_moe_fused_w1w3 = [] # num_experts = 2 * n_routed_experts + muon_params_moe_other = [] # num_experts = n_routed_experts + + for name, p in model.named_parameters(): + if name in trainable_names: + is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + if is_muon_tensor: + if is_moe_model and any(pattern in name for pattern in fused_w1w3_patterns): + muon_params_moe_fused_w1w3.append(p) + elif is_moe_model and any(pattern in name for pattern in other_expert_patterns): + muon_params_moe_other.append(p) + else: + muon_params_regular.append(p) + adamw_params = [ p for name, p in model.named_parameters() if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) ] - param_groups = [ - dict(params=muon_params), - dict(params=adamw_params, algorithm="adamw"), - ] + + # Build parameter groups + param_groups = [] + if muon_params_regular: + param_groups.append(dict(params=muon_params_regular)) + # fused_w1w3: w1 and w3 are fused, so num_experts = 2 * n_routed_experts + if muon_params_moe_fused_w1w3: + param_groups.append(dict(params=muon_params_moe_fused_w1w3, num_experts=2 * num_experts)) + # Other expert params: num_experts = n_routed_experts + if muon_params_moe_other: + param_groups.append(dict(params=muon_params_moe_other, num_experts=num_experts)) + param_groups.append(dict(params=adamw_params, algorithm="adamw")) if dist.get_rank() == 0: logger.info( - f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" + f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, total parameters: {num_total / 1e6:.2f}M" ) - logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)") + if is_moe_model: + logger.info( + f"Muon params: {(num_muon + num_muon_moe) / 1e6:.2f}M " + f"(regular: {num_muon / 1e6:.2f}M, MoE expert: {num_muon_moe / 1e6:.2f}M), " + f"AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)" + ) + logger.info( + f"Detected MoE model with {num_experts} routed experts, " + f"fused_w1w3 uses num_experts={2 * num_experts} (w1+w3), " + f"other expert params use num_experts={num_experts}" + ) + else: + logger.info(f"Muon params: {num_muon / 1e6:.2f}M, AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)") logger.info(f"Untrainable parameters names: {untrainable_names}") logger.info( f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, " diff --git a/xtuner/v1/optim/muon.py b/xtuner/v1/optim/muon.py index 5a8a8200c..3bdf9bfa1 100644 --- a/xtuner/v1/optim/muon.py +++ b/xtuner/v1/optim/muon.py @@ -274,7 +274,9 @@ class Muon(Optimizer): """Distributed Muon optimizer for PyTorch FSDP2. Also compatible with DDP. Args: - params: Parameters for the optimizer. + params: Parameters for the optimizer. Can be a list of parameters or a list of + parameter groups. Each parameter group can specify 'num_experts' to enable + per-expert orthogonalization for MoE models. distributed_mesh: DeviceMesh or ProcessGroup for distributed training. Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. lr: Base learning rate. For Muon, this will be scaled based on the matrix dimensions. @@ -293,7 +295,7 @@ class Muon(Optimizer): False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. - Signature is `func(input: Tensor, epsilon: float) -> Tensor`. + Signature is `func(input: Tensor, epsilon: float, num_experts: int) -> Tensor`. Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra @@ -337,6 +339,7 @@ def __init__( nesterov=nesterov, flatten=flatten, adjust_lr=adjust_lr, + num_experts=1, # Default: no MoE expert handling ) super().__init__(params, defaults) @@ -442,6 +445,7 @@ def _create_muon_tasks( nesterov = group["nesterov"] flatten = group["flatten"] adjust_lr = group["adjust_lr"] + num_experts = group.get("num_experts", 1) # Create batches of parameters of size self._world_size for params in create_param_batches(group_params, batch_size=self._world_size): @@ -497,6 +501,7 @@ def _create_muon_tasks( shard_dim=sharded_tensor_dim, process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, + num_experts=num_experts, ) ) @@ -558,6 +563,7 @@ def muon_update_batch_async( shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, + num_experts: int = 1, # Number of experts for MoE models ) -> Generator[None, None, None]: """Batched version of Muon update. @@ -620,6 +626,7 @@ def muon_update_batch_async( newton_schulz_func=newton_schulz_func, flatten=flatten, epsilon=epsilon, + num_experts=num_experts, ) # Prepare to scatter results back @@ -659,6 +666,7 @@ def muon_update_batch_async( newton_schulz_func=newton_schulz_func, flatten=flatten, epsilon=epsilon, + num_experts=num_experts, ) if process_group is not None and process_group.size() > 1: @@ -688,9 +696,9 @@ def muon_update_batch_async( if adjust_lr is None: adjusted_lr = lr elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, num_experts=num_experts) elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, num_experts=num_experts) else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") @@ -779,6 +787,7 @@ def muon_update_newton_schulz( newton_schulz_func: Callable, flatten: bool, epsilon: Tensor, + num_experts: int = 1, ) -> Tensor: """Flatten the input tensor if needed and call the Newton-Schulz function.""" @@ -790,30 +799,57 @@ def muon_update_newton_schulz( # Given 4D+ batch, flatten to 3D batch X = X.flatten(end_dim=-3) - return newton_schulz_func(X, epsilon=epsilon).reshape(original_shape) + return newton_schulz_func(X, epsilon=epsilon, num_experts=num_experts).reshape(original_shape) -def adjust_lr_rms_norm(lr, param_shape): +def adjust_lr_rms_norm(lr, param_shape, num_experts=1): # Adjust learning rate for constant element-wise RMS norm # https://arxiv.org/abs/2502.16982 - A, B = param_shape[:2] + A = param_shape[-2] // num_experts + B = param_shape[-1] adjusted_ratio = 0.2 * math.sqrt(max(A, B)) adjusted_lr = lr * adjusted_ratio return adjusted_lr -def adjust_lr_spectral_norm(lr, param_shape): +def adjust_lr_spectral_norm(lr, param_shape, num_experts=1): # Adjust from spectral norm 1 to RMS operator norm 1 # https://arxiv.org/abs/2310.17813 - fan_out, fan_in = param_shape[:2] + fan_out = param_shape[-2] // num_experts + fan_in = param_shape[-1] adjusted_lr = lr * math.sqrt(fan_out / fan_in) return adjusted_lr @torch.compile(fullgraph=True) -def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): - """Newton-Schulz iteration to approximate the orthogonalization of X.""" - # Newton-Schulz constants +def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7, num_experts: int = 1): + """Newton-Schulz iteration to approximate the orthogonalization of X. + + This function handles both regular matrices and MoE expert weight matrices. + For MoE models, each expert's weight matrix is orthogonalized independently, + rather than orthogonalizing the concatenated large matrix. + + Unified algorithm for both cases: + 1. Reshape input to (num_experts, M, N) - for regular case this is (1, M, N) + 2. Apply Newton-Schulz iteration to each expert matrix independently using + batch matrix multiplication + 3. Reshape back to original shape + + Mathematical equivalence: + - num_experts=1: X.view(1, M, N) -> process -> X.view(M, N) + This is mathematically equivalent to processing X directly, but allows + unified code path with the MoE case. + - num_experts>1: X.view(num_experts, M, N) -> process each expert -> X.view(num_experts*M, N) + Each expert matrix is orthogonalized independently with its own spectral norm. + + Args: + G: Input tensor to orthogonalize. Shape: (num_experts * M, N) for MoE, + or (M, N) for regular matrices. + epsilon: Small value to avoid division by zero. + num_experts: Number of experts for MoE models. Default 1 for regular matrices. + When > 1, the input is treated as concatenated expert matrices. + """ + # Newton-Schulz constants - fixed coefficients for 5th order iteration ns_consts = [ (4.0848, -6.8946, 2.9270), (3.9505, -6.3029, 2.6377), @@ -823,17 +859,38 @@ def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): ] X = G.to(dtype=torch.bfloat16) - if G.size(-2) > G.size(-1): - X = X.mT + original_shape = X.shape + + # Unified handling: reshape to (num_experts, M, N) for both cases + # For regular case (num_experts=1), this adds a batch dimension of size 1 + N = X.size(-1) + X = X.view(num_experts, -1, N) + + # Transpose if needed (when rows > cols) for numerical stability in NS iteration + # This ensures X @ X.mT produces a smaller square matrix + need_transpose = X.size(-2) > X.size(-1) + if need_transpose: + X = X.mT # (num_experts, N, M) if rows > cols, else (num_experts, M, N) - # Ensure spectral norm is at most 1 + # Ensure spectral norm is at most 1 for each expert matrix independently + # norm shape: (num_experts, 1, 1) - each expert has its own normalization factor X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) + # Newton-Schulz iteration: orthogonalize each expert matrix + # Using batch matrix multiplication (@) to process all experts in parallel for a, b, c in ns_consts: - A = X @ X.mT + # A = X @ X^T: compute Gram matrix for each expert + A = X @ X.mT # shape: (num_experts, M, M) or (num_experts, N, N) + # B = b * A + c * A @ A: polynomial combination for convergence B = b * A + c * (A @ A) + # X = a * X + B @ X: update step X = a * X + B @ X - if G.size(-2) > G.size(-1): + # Undo transpose if applied + if need_transpose: X = X.mT + + # Reshape back to original shape: (num_experts, M, N) -> (num_experts * M, N) + X = X.view(original_shape) + return X diff --git a/xtuner/v1/optim/newton_schulz_triton.py b/xtuner/v1/optim/newton_schulz_triton.py index c9aa8c91e..9337592f8 100644 --- a/xtuner/v1/optim/newton_schulz_triton.py +++ b/xtuner/v1/optim/newton_schulz_triton.py @@ -283,37 +283,18 @@ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor | None = None # @torch.compile(dynamic=False, fullgraph=True) -def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): - """Reference implementation of Newton-Schulz without Triton.""" - # Newton-Schulz constants - ns_consts = [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ] - - X = G.to(dtype=torch.bfloat16) - if G.size(-2) > G.size(-1): - X = X.mT +def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7, num_experts: int = 1): + """Triton implementation of Newton-Schulz iteration. - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) + Unified implementation for both regular matrices and MoE expert matrices. + Uses reshape to (num_experts, M, N) to handle both cases with the same code path. - for a, b, c in ns_consts: - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - - if G.size(-2) > G.size(-1): - X = X.mT - return X - - -# @torch.compile(dynamic=False, fullgraph=True) -def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7): - """Triton implementation of Newton-Schulz iteration.""" + Args: + G: Input tensor to orthogonalize. Shape: (num_experts * M, N) for MoE, + or (M, N) for regular matrices. + epsilon: Small value to avoid division by zero. + num_experts: Number of experts for MoE models. Default 1 for regular matrices. + """ # Newton-Schulz constants ns_consts = [ (4.0848, -6.8946, 2.9270), @@ -324,27 +305,39 @@ def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7): ] X = G.to(dtype=torch.bfloat16) - if G.size(-2) > G.size(-1): + original_shape = X.shape + + # Unified reshape: (num_experts * M, N) -> (num_experts, M, N) + # For num_experts=1, this is (M, N) -> (1, M, N), adding a batch dim + N = X.size(-1) + X = X.view(num_experts, -1, N) + + # Transpose if rows > cols for numerical stability + need_transpose = X.size(-2) > X.size(-1) + if need_transpose: X = X.mT - # Ensure spectral norm is at most 1 + # Normalize each expert matrix independently + # norm shape: (num_experts, 1, 1) - each expert has its own normalization factor X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) - # Allocate buffers + # Allocate buffers for 3D tensors X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + A = torch.empty((num_experts, X.size(-2), X.size(-2)), device=X.device, dtype=X.dtype) B = torch.empty_like(A) C = torch.empty_like(X) - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations + # Perform the NS iterations using batch matrix operations for a, b, c in ns_consts: ns_line_1(X, out=A) # A = X @ X.mT ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + torch.baddbmm(X, B, X, beta=a, out=C) # C = a * X + B @ X X, C = C, X # Swap references to avoid unnecessary copies - if G.size(-2) > G.size(-1): + if need_transpose: X = X.mT + + # Reshape back to original shape + X = X.reshape(original_shape) + return X