Skip to content

Commit bc35c4e

Browse files
committed
[Enhance] enhance muon config
1. add adjust_lr arg to MuonConfig and refactor muon build code; 2. avoid assigning 1D params to Muon (e.g. those of shape [1, D]); 3. default flatten of Muon param to True as currently batched params are rarely used.
1 parent 8ec408b commit bc35c4e

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

xtuner/v1/config/optim.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ class MuonConfig(OptimConfig):
6262
momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95
6363
betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95)
6464
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8
65+
adjust_lr: Annotated[
66+
Literal["rms_norm", "constant", "none"], Parameter(help="Method for adjusting lr in Muon")
67+
] = "rms_norm"
6568

6669
def build(self, model):
6770
trainable_parameters_names = model.trainable_parameters()
@@ -73,39 +76,37 @@ def build(self, model):
7376
num_muon = 0
7477
num_adamw = 0
7578

79+
muon_params = []
80+
adamw_params = []
81+
7682
for name, p in model.named_parameters():
7783
n = p.numel()
7884
num_total += n
7985
if name in trainable_names:
8086
num_total_requires_grad += n
81-
is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
87+
is_muon_tensor = (
88+
p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name and p.numel() not in p.shape
89+
)
8290
if is_muon_tensor:
91+
muon_params.append(p)
8392
num_muon += n
8493
else:
94+
adamw_params.append(p)
8595
num_adamw += n
8696
else:
8797
untrainable_names.append(name)
8898

89-
muon_params = [
90-
p
91-
for name, p in model.named_parameters()
92-
if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
93-
]
94-
adamw_params = [
95-
p
96-
for name, p in model.named_parameters()
97-
if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
98-
]
9999
param_groups = [
100100
dict(params=muon_params),
101101
dict(params=adamw_params, algorithm="adamw"),
102102
]
103103

104104
if dist.get_rank() == 0:
105105
logger.info(
106-
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
106+
f"Total trainable parameters: {num_total_requires_grad / 1e6:.2f}M, "
107+
f"total parameters: {num_total / 1e6:.2f}M"
107108
)
108-
logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
109+
logger.info(f"Muon params: {num_muon / 1e6:.2f}M, AdamW params: {num_adamw / 1e6:.2f}M (counts by numel)")
109110
logger.info(f"Untrainable parameters names: {untrainable_names}")
110111
logger.info(
111112
f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, "
@@ -121,6 +122,7 @@ def build(self, model):
121122
weight_decay=self.weight_decay,
122123
nesterov=True,
123124
adjust_lr=self.adjust_lr,
125+
flatten=True,  # TODO:@nil0x9 would be nice if we have fine-grained control here.
124126
use_triton=False,
125127
epsilon=self.eps,
126128
)

xtuner/v1/optim/muon.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ class Muon(Optimizer):
284284
weight_decay: Weight decay factor.
285285
epsilon: Small value to avoid division by zero.
286286
nesterov: Whether to use Nesterov momentum.
287-
adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None).
287+
adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or "none").
288288
"spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale.
289289
"rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW.
290290
"none": Do not adjust the learning rate.
@@ -309,7 +309,7 @@ def __init__(
309309
weight_decay: float = 0.01,
310310
epsilon: float = 1e-8,
311311
nesterov: bool = False,
312-
adjust_lr: Optional[str] = "spectral_norm",
312+
adjust_lr: str = "rms_norm",
313313
flatten: bool = False,
314314
use_triton: bool = False,
315315
newton_schulz_func: Optional[Callable] = None,
@@ -321,8 +321,8 @@ def __init__(
321321
raise ValueError(f"Invalid momentum factor (mu): {mu}")
322322
if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0:
323323
raise ValueError(f"Invalid betas: {betas}")
324-
if adjust_lr not in ("spectral_norm", "rms_norm", None):
325-
raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None.")
324+
if adjust_lr not in ("spectral_norm", "rms_norm", "none"):
325+
raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or 'none'.")
326326

327327
# Default arguments for each param group
328328
defaults = dict(
@@ -552,7 +552,7 @@ def muon_update_batch_async(
552552
epsilon: Tensor, # Epsilon (scalar tensor)
553553
nesterov: bool, # Whether to use Nesterov momentum
554554
flatten: bool, # Whether to flatten 3D+ tensors to 2D
555-
adjust_lr: Optional[str], # How to adjust learning rate
555+
adjust_lr: str, # How to adjust learning rate
556556
device_rank: int, # Rank of the current device
557557
world_size: int, # Total number of devices to parallelize over
558558
shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable)
@@ -685,7 +685,7 @@ def muon_update_batch_async(
685685

686686
# Compute scaled learning rate
687687
# Do this before to_local(X) because we use the full tensor shape, not the shard shape
688-
if adjust_lr is None:
688+
if adjust_lr == "none":
689689
adjusted_lr = lr
690690
elif adjust_lr == "spectral_norm":
691691
adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape)

0 commit comments

Comments
 (0)