diff --git a/xtuner/v1/config/optim.py b/xtuner/v1/config/optim.py index 5827edd8c..9b2d2b41f 100644 --- a/xtuner/v1/config/optim.py +++ b/xtuner/v1/config/optim.py @@ -62,6 +62,9 @@ class MuonConfig(OptimConfig): momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95 betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95) eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8 + adjust_lr: Annotated[ + Literal["rms_norm", "spectral_norm", "none"], Parameter(help="Method for adjusting lr in Muon") + ] = "rms_norm" def build(self, model): trainable_parameters_names = model.trainable_parameters() @@ -73,29 +76,29 @@ def build(self, model): num_muon = 0 num_adamw = 0 + muon_params = [] + adamw_params = [] + for name, p in model.named_parameters(): n = p.numel() num_total += n if name in trainable_names: num_total_requires_grad += n - is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + # we want to avoid using Muon for 1D-tensors, as well as embed_tokens and lm_head. + # effectively-1D tensors where one dimension accounts for all elements (e.g. shape [1, D]) should + # also be excluded. + is_muon_tensor = ( + p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name and p.numel() not in p.shape + ) if is_muon_tensor: + muon_params.append(p) num_muon += n else: + adamw_params.append(p) num_adamw += n else: untrainable_names.append(name) - muon_params = [ - p - for name, p in model.named_parameters() - if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ] - adamw_params = [ - p - for name, p in model.named_parameters() - if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) - ] param_groups = [ dict(params=muon_params), dict(params=adamw_params, algorithm="adamw"), @@ -103,9 +106,10 @@ def build(self, model): if dist.get_rank() == 0: logger.info( - f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" + f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, " + f"total parameters: {num_total / 1e6:.2f}M" ) - logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)") + logger.info(f"Muon params: {num_muon / 1e6:.2f}M, AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)") logger.info(f"Untrainable parameters names: {untrainable_names}") logger.info( f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, " @@ -120,7 +124,8 @@ def build(self, model): betas=self.betas, weight_decay=self.weight_decay, nesterov=True, - adjust_lr="rms_norm", + adjust_lr=self.adjust_lr, + flatten=True, # TODO:@nil0x9 would be nice if we have fine-grained control here. use_triton=False, epsilon=self.eps, ) diff --git a/xtuner/v1/optim/muon.py b/xtuner/v1/optim/muon.py index 5a8a8200c..737921237 100644 --- a/xtuner/v1/optim/muon.py +++ b/xtuner/v1/optim/muon.py @@ -284,10 +284,10 @@ class Muon(Optimizer): weight_decay: Weight decay factor. epsilon: Small value to avoid division by zero. nesterov: Whether to use Nesterov momentum. - adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). + adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or "none"). "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. - None: Do not adjust the learning rate. + "none": Do not adjust the learning rate. flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. @@ -309,7 +309,7 @@ def __init__( weight_decay: float = 0.01, epsilon: float = 1e-8, nesterov: bool = False, - adjust_lr: Optional[str] = "spectral_norm", + adjust_lr: str = "rms_norm", flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, @@ -321,8 +321,8 @@ def __init__( raise ValueError(f"Invalid momentum factor (mu): {mu}") if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: raise ValueError(f"Invalid betas: {betas}") - if adjust_lr not in ("spectral_norm", "rms_norm", None): - raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None.") + if adjust_lr not in ("spectral_norm", "rms_norm", "none"): + raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or 'none'.") # Default arguments for each param group defaults = dict( @@ -552,7 +552,7 @@ def muon_update_batch_async( epsilon: Tensor, # Epsilon (scalar tensor) nesterov: bool, # Whether to use Nesterov momentum flatten: bool, # Whether to flatten 3D+ tensors to 2D - adjust_lr: Optional[str], # How to adjust learning rate + adjust_lr: str, # How to adjust learning rate device_rank: int, # Rank of the current device world_size: int, # Total number of devices to parallelize over shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) @@ -685,7 +685,7 @@ def muon_update_batch_async( # Compute scaled learning rate # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: + if adjust_lr == "none": adjusted_lr = lr elif adjust_lr == "spectral_norm": adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape)