@@ -284,7 +284,7 @@ class Muon(Optimizer):
284284 weight_decay: Weight decay factor.
285285 epsilon: Small value to avoid division by zero.
286286 nesterov: Whether to use Nesterov momentum.
287- adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None ).
287+ adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm", "rms_norm", or "none").
288288 "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale.
289289 "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW.
290290 "none": Do not adjust the learning rate.
@@ -309,7 +309,7 @@ def __init__(
309309 weight_decay : float = 0.01 ,
310310 epsilon : float = 1e-8 ,
311311 nesterov : bool = False ,
312- adjust_lr : Optional [ str ] = "spectral_norm " ,
312+ adjust_lr : str = "rms_norm " ,
313313 flatten : bool = False ,
314314 use_triton : bool = False ,
315315 newton_schulz_func : Optional [Callable ] = None ,
@@ -321,8 +321,8 @@ def __init__(
321321 raise ValueError (f"Invalid momentum factor (mu): { mu } " )
322322 if len (betas ) != 2 or betas [0 ] < 0.0 or betas [1 ] < 0.0 :
323323 raise ValueError (f"Invalid betas: { betas } " )
324- if adjust_lr not in ("spectral_norm" , "rms_norm" , None ):
325- raise ValueError (f"Invalid adjust_lr value: { adjust_lr } . Must be 'spectral_norm', 'rms_norm', or None ." )
324+ if adjust_lr not in ("spectral_norm" , "rms_norm" , "none" ):
325+ raise ValueError (f"Invalid adjust_lr value: { adjust_lr } . Must be 'spectral_norm', 'rms_norm', or 'none' ." )
326326
327327 # Default arguments for each param group
328328 defaults = dict (
@@ -552,7 +552,7 @@ def muon_update_batch_async(
552552 epsilon : Tensor , # Epsilon (scalar tensor)
553553 nesterov : bool , # Whether to use Nesterov momentum
554554 flatten : bool , # Whether to flatten 3D+ tensors to 2D
555- adjust_lr : Optional [ str ] , # How to adjust learning rate
555+ adjust_lr : str , # How to adjust learning rate
556556 device_rank : int , # Rank of the current device
557557 world_size : int , # Total number of devices to parallelize over
558558 shard_dim : Optional [int ] = None , # Shard dimension for DTensor (if applicable)
@@ -685,7 +685,7 @@ def muon_update_batch_async(
685685
686686 # Compute scaled learning rate
687687 # Do this before to_local(X) because we use the full tensor shape, not the shard shape
688- if adjust_lr is None :
688+ if adjust_lr == "none" :
689689 adjusted_lr = lr
690690 elif adjust_lr == "spectral_norm" :
691691 adjusted_lr = adjust_lr_spectral_norm (lr , X [0 ].shape )
0 commit comments