-
Notifications
You must be signed in to change notification settings - Fork 410
[Enhance] enhance muon config #1610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,9 @@ class MuonConfig(OptimConfig): | |
| momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95 | ||
| betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95) | ||
| eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8 | ||
nil0x9 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| adjust_lr: Annotated[ | ||
| Literal["rms_norm", "spectral_norm", "none"], Parameter(help="Method for adjusting lr in Muon") | ||
| ] = "rms_norm" | ||
|
|
||
| def build(self, model): | ||
| trainable_parameters_names = model.trainable_parameters() | ||
|
|
@@ -73,39 +76,40 @@ def build(self, model): | |
| num_muon = 0 | ||
| num_adamw = 0 | ||
|
|
||
| muon_params = [] | ||
| adamw_params = [] | ||
|
|
||
| for name, p in model.named_parameters(): | ||
| n = p.numel() | ||
| num_total += n | ||
| if name in trainable_names: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Nit: The |
||
| num_total_requires_grad += n | ||
| is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name | ||
| # we want to avoid using Muon for 1D-tensors, as well as embed_tokens and lm_head. | ||
| # effectively-1D tensors where one dimension accounts for all elements (e.g. shape [1, D]) should | ||
| # also be excluded. | ||
| is_muon_tensor = ( | ||
| p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name and p.numel() not in p.shape | ||
| ) | ||
| if is_muon_tensor: | ||
| muon_params.append(p) | ||
| num_muon += n | ||
| else: | ||
| adamw_params.append(p) | ||
| num_adamw += n | ||
| else: | ||
| untrainable_names.append(name) | ||
|
|
||
| muon_params = [ | ||
| p | ||
| for name, p in model.named_parameters() | ||
| if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name | ||
| ] | ||
| adamw_params = [ | ||
| p | ||
| for name, p in model.named_parameters() | ||
| if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) | ||
| ] | ||
| param_groups = [ | ||
| dict(params=muon_params), | ||
| dict(params=adamw_params, algorithm="adamw"), | ||
| ] | ||
|
|
||
| if dist.get_rank() == 0: | ||
| logger.info( | ||
| f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" | ||
| f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, " | ||
| f"total parameters: {num_total / 1e6:.2f}M" | ||
| ) | ||
| logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)") | ||
| logger.info(f"Muon params: {num_muon / 1e6:.2f}M, AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)") | ||
| logger.info(f"Untrainable parameters names: {untrainable_names}") | ||
| logger.info( | ||
| f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, " | ||
|
|
@@ -120,7 +124,8 @@ def build(self, model): | |
| betas=self.betas, | ||
| weight_decay=self.weight_decay, | ||
| nesterov=True, | ||
| adjust_lr="rms_norm", | ||
| adjust_lr=self.adjust_lr, | ||
| flatten=True, # TODO:@nil0x9 would be nice if we have fine-grained control here. | ||
| use_triton=False, | ||
| epsilon=self.eps, | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -284,10 +284,10 @@ class Muon(Optimizer): | |
| weight_decay: Weight decay factor. | ||
| epsilon: Small value to avoid division by zero. | ||
| nesterov: Whether to use Nesterov momentum. | ||
| adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). | ||
| adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or "none"). | ||
| "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. | ||
| "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. | ||
| None: Do not adjust the learning rate. | ||
| "none": Do not adjust the learning rate. | ||
| flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. | ||
| True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. | ||
| False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. | ||
|
|
@@ -309,7 +309,7 @@ def __init__( | |
| weight_decay: float = 0.01, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Warning: This is a breaking change — both the default value changed. Consider adding backward compatibility:

    if adjust_lr is None:
        adjust_lr = "none"

or at minimum document this as a breaking change in the PR description. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Warning — Breaking API change: Two things changed here simultaneously:
Both changes may be intentional, but they could break downstream users of the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change is intentional. If we want "rms_norm" adjusting method to be the default behavior, then having an Optional[str] arg where None corresponds to "not adjusting" is paradoxical and confusing. |
||
| epsilon: float = 1e-8, | ||
| nesterov: bool = False, | ||
| adjust_lr: Optional[str] = "spectral_norm", | ||
| adjust_lr: str = "rms_norm", | ||
| flatten: bool = False, | ||
| use_triton: bool = False, | ||
| newton_schulz_func: Optional[Callable] = None, | ||
|
|
@@ -321,8 +321,8 @@ def __init__( | |
| raise ValueError(f"Invalid momentum factor (mu): {mu}") | ||
| if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: | ||
| raise ValueError(f"Invalid betas: {betas}") | ||
| if adjust_lr not in ("spectral_norm", "rms_norm", None): | ||
| raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None.") | ||
| if adjust_lr not in ("spectral_norm", "rms_norm", "none"): | ||
| raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or 'none'.") | ||
|
|
||
| # Default arguments for each param group | ||
| defaults = dict( | ||
|
|
@@ -552,7 +552,7 @@ def muon_update_batch_async( | |
| epsilon: Tensor, # Epsilon (scalar tensor) | ||
| nesterov: bool, # Whether to use Nesterov momentum | ||
| flatten: bool, # Whether to flatten 3D+ tensors to 2D | ||
| adjust_lr: Optional[str], # How to adjust learning rate | ||
| adjust_lr: str, # How to adjust learning rate | ||
| device_rank: int, # Rank of the current device | ||
| world_size: int, # Total number of devices to parallelize over | ||
| shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) | ||
|
|
@@ -685,7 +685,7 @@ def muon_update_batch_async( | |
|
|
||
| # Compute scaled learning rate | ||
| # Do this before to_local(X) because we use the full tensor shape, not the shard shape | ||
| if adjust_lr is None: | ||
| if adjust_lr == "none": | ||
| adjusted_lr = lr | ||
| elif adjust_lr == "spectral_norm": | ||
| adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.