Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions xtuner/v1/config/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class MuonConfig(OptimConfig):
momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95
betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95)
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8
adjust_lr: Annotated[
Literal["rms_norm", "spectral_norm", "none"], Parameter(help="Method for adjusting lr in Muon")
] = "rms_norm"

def build(self, model):
trainable_parameters_names = model.trainable_parameters()
Expand All @@ -73,39 +76,40 @@ def build(self, model):
num_muon = 0
num_adamw = 0

muon_params = []
adamw_params = []

for name, p in model.named_parameters():
n = p.numel()
num_total += n
if name in trainable_names:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: The p.numel() not in p.shape check is clever but non-obvious. A brief comment explaining the intent would help future readers — e.g., "exclude effectively-1D tensors where one dimension accounts for all elements (e.g. shape [1, D])."

num_total_requires_grad += n
is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
# we want to avoid using Muon for 1D-tensors, as well as embed_tokens and lm_head.
# effectively-1D tensors where one dimension accounts for all elements (e.g. shape [1, D]) should
# also be excluded.
is_muon_tensor = (
p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name and p.numel() not in p.shape
)
if is_muon_tensor:
muon_params.append(p)
num_muon += n
else:
adamw_params.append(p)
num_adamw += n
else:
untrainable_names.append(name)

muon_params = [
p
for name, p in model.named_parameters()
if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
]
adamw_params = [
p
for name, p in model.named_parameters()
if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
]
param_groups = [
dict(params=muon_params),
dict(params=adamw_params, algorithm="adamw"),
]

if dist.get_rank() == 0:
logger.info(
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, "
f"total parameters: {num_total / 1e6:.2f}M"
)
logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
logger.info(f"Muon params: {num_muon / 1e6:.2f}M, AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)")
logger.info(f"Untrainable parameters names: {untrainable_names}")
logger.info(
f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, "
Expand All @@ -120,7 +124,8 @@ def build(self, model):
betas=self.betas,
weight_decay=self.weight_decay,
nesterov=True,
adjust_lr="rms_norm",
adjust_lr=self.adjust_lr,
flatten=True, # TODO:@nil0x9 would be nice if we have fine-grained control here.
use_triton=False,
epsilon=self.eps,
)
Expand Down
14 changes: 7 additions & 7 deletions xtuner/v1/optim/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,10 @@ class Muon(Optimizer):
weight_decay: Weight decay factor.
epsilon: Small value to avoid division by zero.
nesterov: Whether to use Nesterov momentum.
adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None).
adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or "none").
"spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale.
"rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW.
None: Do not adjust the learning rate.
"none": Do not adjust the learning rate.
flatten: Whether to flatten 3D+ tensors to 2D for Muon updates.
True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers.
False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices.
Expand All @@ -309,7 +309,7 @@ def __init__(
weight_decay: float = 0.01,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning: This is a breaking change — both the default value ("spectral_norm" → "rms_norm") and the type (Optional[str] → str, None → "none" string) changed. Any existing callers passing adjust_lr=None will now get a ValueError. Consider keeping backward compatibility:

if adjust_lr is None:
    adjust_lr = "none"

or at minimum document this as a breaking change in the PR description.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning — Breaking API change: Two things changed here simultaneously:

  1. The type changed from Optional[str] to str — any caller passing adjust_lr=None will now get a ValueError.
  2. The default changed from "spectral_norm" to "rms_norm" — existing callers relying on the default will silently get different behavior.

Both changes may be intentional, but they could break downstream users of the Muon class directly. Consider either:

  • Accepting None as a deprecated alias for "none" (with a deprecation warning), or
  • Documenting this as a known breaking change in the PR description.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is intentional. If we want the "rms_norm" adjusting method to be the default behavior, then having an Optional[str] arg where None corresponds to "not adjusting" is paradoxical and confusing.

epsilon: float = 1e-8,
nesterov: bool = False,
adjust_lr: Optional[str] = "spectral_norm",
adjust_lr: str = "rms_norm",
flatten: bool = False,
use_triton: bool = False,
newton_schulz_func: Optional[Callable] = None,
Expand All @@ -321,8 +321,8 @@ def __init__(
raise ValueError(f"Invalid momentum factor (mu): {mu}")
if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0:
raise ValueError(f"Invalid betas: {betas}")
if adjust_lr not in ("spectral_norm", "rms_norm", None):
raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None.")
if adjust_lr not in ("spectral_norm", "rms_norm", "none"):
raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or 'none'.")

# Default arguments for each param group
defaults = dict(
Expand Down Expand Up @@ -552,7 +552,7 @@ def muon_update_batch_async(
epsilon: Tensor, # Epsilon (scalar tensor)
nesterov: bool, # Whether to use Nesterov momentum
flatten: bool, # Whether to flatten 3D+ tensors to 2D
adjust_lr: Optional[str], # How to adjust learning rate
adjust_lr: str, # How to adjust learning rate
device_rank: int, # Rank of the current device
world_size: int, # Total number of devices to parallelize over
shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable)
Expand Down Expand Up @@ -685,7 +685,7 @@ def muon_update_batch_async(

# Compute scaled learning rate
# Do this before to_local(X) because we use the full tensor shape, not the shard shape
if adjust_lr is None:
if adjust_lr == "none":
adjusted_lr = lr
elif adjust_lr == "spectral_norm":
adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape)
Expand Down
Loading