[Fix] Support fp32 param preservation during FSDP save and load #1607

Merged
HAOCHENYE merged 1 commit into InternLM:main from HAOCHENYE:fp32-param-train
Mar 23, 2026

Conversation

@HAOCHENYE
Collaborator

  • Add fp32_keys_pattern to HFSaveCfg to specify params that should
    be saved in fp32 regardless of the global save dtype
  • Add _fully_shard() to BaseModel which wraps fully_shard() and
    distributes matched params as Replicate DTensors so FSDP ignores them
  • Replace all direct fully_shard() call sites with self._fully_shard()
  • Add _get_save_dtype() to select fp32 per-param at save time
  • Fix load path in _load_same_hf_param to skip shard-offset logic for
    Replicate DTensors (only apply for Shard-placed params)
  • Fix world_mesh property setter bug: use _world_mesh directly
    instead of assigning through the read-only property
  • Fix gradient reduce in MoE.scale_and_reduce_grad to correctly
    identify the Replicate mesh dimension via DTensor placements
  • Preserve requires_grad when distributing params in EP and fp32 paths
  • Configure Qwen3_5_VLTextMoEConfig with fp32 patterns for
    linear_attn.norm.weight and linear_attn.A_log
  • Add test_save_hf_with_mtp to verify round-trip weight preservation

@HAOCHENYE
Collaborator Author

@claude review

Comment on lines +876 to +878
if isinstance(param, DTensor):
    replicate_placements = [p for p in param.placements if isinstance(p, Replicate)]
    if replicate_placements:

Claude: Warning — Only the first Replicate dimension is reduced.

replicate_placements may contain more than one entry (e.g. for fp32-ignored params distributed as [Replicate(), Replicate()] on the 2D world_mesh). This loop only reduces across the first Replicate dimension, silently skipping subsequent ones.

Today this isn't triggered because matching fp32 params in MoE are already DTensors (from _replicate_other_params) with a single [Replicate()] on ep_mesh, so _fully_shard just adds them to ignored_params without re-distributing. But if a future change distributes a param as [Replicate(), Replicate()] on the 2D model mesh and it's trainable, its gradient will only be reduced across one mesh dimension.

Consider either:

  1. Reducing across all Replicate dimensions in a loop, or
  2. Adding an assertion / comment that only a single Replicate placement is expected here.
Suggested change
if isinstance(param, DTensor):
    replicate_placements = [p for p in param.placements if isinstance(p, Replicate)]
    if replicate_placements:

if isinstance(param, DTensor):
    replicate_placements = [p for p in param.placements if isinstance(p, Replicate)]
    # NOTE: currently all trainable Replicate params have exactly one
    # Replicate placement (from _replicate_other_params on ep_mesh).
    # If multi-Replicate params appear, this needs to reduce across
    # every Replicate dimension.
    if replicate_placements:

bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
# TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32,
# currently it could only be specified in HFSaveCfg
fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None

Claude: Nit — Same regex dot-escaping concern applies here. Patterns from fp32_keys_pattern are used as regex via re.search. If pattern authors forget to escape dots, matches will be overly broad. Consider documenting that values are regex (not glob/fnmatch) in the field docstring, or normalizing dots on behalf of the user.
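The escaping concern is easy to demonstrate in isolation. In the sketch below the key and pattern strings are made up for illustration; the point is only that under re.search an unescaped dot matches any character, so a pattern meant to be literal can match unrelated parameter names:

```python
import re

# Hypothetical pattern/key values, chosen only to show the escaping pitfall.
unescaped = ["linear_attn.norm.weight"]    # "." matches any character
escaped = [r"linear_attn\.norm\.weight"]   # "\." matches a literal dot only

key = "linear_attnXnormXweight"  # not a real parameter name

print(any(re.search(p, key) for p in unescaped))  # True: overly broad match
print(any(re.search(p, key) for p in escaped))    # False: literal dots required
```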

continue
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], dtype))

Claude: Warning — Pre-existing hardcoded .bfloat16() replaced, but only for one of two callers.

Good fix — the original .bfloat16() ignored the save_dtype parameter. However, the same pattern also exists in _get_fused_hf_param. Make sure that path is covered too (it appears to be at line ~945 in the diff — just confirming this is intentional and not missed).
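For reference, the per-param dtype selection the PR describes could look roughly like the sketch below. The function name and string dtypes are stand-ins to keep the example self-contained; the real _get_save_dtype presumably returns torch dtypes and reads fp32_keys_pattern from the save config:

```python
import re

def get_save_dtype(hf_key, default_dtype, fp32_keys_pattern=None):
    # Sketch only: return fp32 for params whose HF key matches a configured
    # pattern, otherwise fall back to the global save dtype.
    if fp32_keys_pattern and any(re.search(p, hf_key) for p in fp32_keys_pattern):
        return "float32"
    return default_dtype

patterns = [r"linear_attn\.norm\.weight", r"linear_attn\.A_log"]
print(get_save_dtype("model.layers.0.linear_attn.A_log", "bfloat16", patterns))   # float32
print(get_save_dtype("model.layers.0.mlp.up_proj.weight", "bfloat16", patterns))  # bfloat16
```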

Comment on lines +370 to +382
    dist_param = nn.Parameter(
        distribute_tensor(
            param,
            self.world_mesh,
            [Replicate() for _ in range(self.world_mesh.ndim)]
        ),
        requires_grad=param.requires_grad
    )
    module.register_parameter(name, dist_param)
    ignored_params.add(dist_param)
else:
    ignored_params.add(param)
break

Claude: The _fully_shard wrapper is well-designed — collecting fp32-pattern params into ignored_params and distributing them as Replicate DTensors is the right approach to keep FSDP from sharding them.

One subtle note: when a param is already a DTensor (e.g., after _replicate_other_params in MoE), the else branch just adds it to ignored_params without re-distributing on self.world_mesh. This is correct today because MoE's _replicate_other_params already distributed it as [Replicate()] on ep_mesh. But if the caller order ever changes (fp32 distribution before EP distribution), this would silently skip the distribution. A comment noting this assumption would help.

Comment on lines +18 to +21
import tempfile
from pathlib import Path
import json
from safetensors import safe_open

Claude: Nit — Import placement. Per convention in this file, stdlib and third-party imports should be grouped with the existing imports at the top (lines 1–15), not appended after.

@claude

claude bot commented Mar 20, 2026

Claude: Summary

This PR adds fp32 parameter preservation during FSDP save/load by introducing fp32_keys_pattern matching, a _fully_shard() wrapper that excludes matched params via ignored_params, and per-param save dtype selection. It also fixes several related bugs: world_mesh property setter, hardcoded .bfloat16() in save paths, gradient reduction mesh selection for replicated params, and requires_grad preservation in EP distribution.

Issues

Warning

  • moe.py:876-878 — scale_and_reduce_grad only reduces across the first Replicate placement. If a param ever has multiple Replicate placements (e.g. [Replicate(), Replicate()] on the 2D world mesh), subsequent dimensions are silently skipped. Not triggered today, but fragile.
  • base.py:990 — Confirm the parallel .bfloat16() fix in _get_fused_hf_param is also covered.

Verdict

COMMENT — The core approach is sound. The _fully_shard wrapper with ignored_params is a clean way to preserve fp32 params through FSDP. The gradient reduction change in scale_and_reduce_grad works correctly for the current use case but has a latent fragility worth acknowledging with a comment.

@HAOCHENYE force-pushed the fp32-param-train branch 3 times, most recently from 7980909 to ac307bf on March 20, 2026 16:41
@HAOCHENYE
Collaborator Author

@claude review

Comment on lines +364 to +371
for name, param in module.named_parameters(recurse=False):
    full_name = full_param_name_mapping[id(param)]
    full_name = self._clean_param_name(full_name)
    hf_name_list = self.to_hf_key_list(full_name)

    for hf_name in hf_name_list:
        if any(re.search(p, hf_name) for p in patterns):  # type: ignore
            if not isinstance(param, DTensor):

Claude: Warning — _split_ignored_params yield can produce empty lists.

In _get_shard_hf_param, _get_fused_hf_param, and _get_same_hf_param, after calling self._split_ignored_params(params), the code unconditionally yields name_list, hf_params from the ignored params. If there are no ignored params (i.e., fp32_keys_pattern is None or no pattern matches), this yields ([], []).

Downstream in _save_hf, the fused_gen path has a guard (if not name_list: continue) but the chain(same_gen, shard_gen) path does not — this could cause empty safetensor files to be written and gaps in file numbering.

Consider guarding the yield:

Suggested change
for name, param in module.named_parameters(recurse=False):
    full_name = full_param_name_mapping[id(param)]
    full_name = self._clean_param_name(full_name)
    hf_name_list = self.to_hf_key_list(full_name)
    for hf_name in hf_name_list:
        if any(re.search(p, hf_name) for p in patterns):  # type: ignore
            if not isinstance(param, DTensor):

ignored_params, params = self._split_ignored_params(params)
name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
if name_list:
    yield name_list, hf_params

The same pattern should be applied in _get_fused_hf_param and _get_same_hf_param.
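Stripped of the xtuner specifics, the guarded-yield fix reduces to a generic generator pattern (names below are illustrative, not the actual code): skip empty groups so downstream consumers never open a file for them.

```python
def iter_save_groups(groups):
    # Mirror of the suggested `if name_list:` guard: drop empty
    # (name_list, params) pairs instead of yielding ([], []).
    for name_list, params in groups:
        if name_list:
            yield name_list, params

groups = [(["a"], [1]), ([], []), (["b", "c"], [2, 3])]
print(list(iter_save_groups(groups)))  # [(['a'], [1]), (['b', 'c'], [2, 3])]
```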

Comment on lines +877 to 896
    param.device_mesh.mesh_dim_names[i]
    for i, p in enumerate(param.placements)
    if isinstance(p, Replicate)
)
if replicate_dim_names:
    # `DeviceMesh.get_group()` only supports a single mesh dimension,
    # so calling it directly on a multi-dim sub-mesh raises RuntimeError.
    # `_flatten()` collapses all Replicate dims into a 1D mesh whose
    # process group covers every rank across those dimensions, allowing
    # a single all_reduce regardless of how many Replicate dims exist.
    flat_mesh = param.device_mesh[replicate_dim_names]._flatten()
    grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
    dist.all_reduce(
        grad.div_(flat_mesh.size()),  # type: ignore
        ReduceOp.SUM,
        group=flat_mesh.get_group(),  # type: ignore
    )

def _init_device_mesh(self, fsdp_config: FSDPConfig):
    self.fsdp_config = fsdp_config

Claude: Warning — Non-DTensor params with ep_enabled=True now silently skip gradient reduction.

The old code reduced all non-expert params when ep_enabled=True:

if ep_enabled:
    grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
    dist.all_reduce(grad.div_(...), ...)

The new code only reduces DTensor params that have Replicate placements. If ep_enabled=True but a param is not a DTensor (shouldn't happen in normal FSDP, but defensive coding matters), its gradient is silently dropped from reduction.

Also, DTensor params with only Shard placements and no Replicate get replicate_dim_names = (), so their gradients are also not reduced. Confirm that no such params exist in practice.

The _flatten() approach for multi-dim Replicate meshes is a good improvement over the original hardcoded mesh_dim=0.
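The placement-driven mesh-dim selection under discussion can be illustrated without torch.distributed. The Replicate/Shard classes below are local stand-ins for the DTensor placement types, not the real ones, and the function is a simplified model of the selection logic:

```python
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""

class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    def __init__(self, dim):
        self.dim = dim

def replicate_dim_names(placements, mesh_dim_names):
    # One placement per mesh dimension: keep the names of the dims on
    # which the param is replicated (those need a gradient all_reduce).
    return tuple(
        mesh_dim_names[i]
        for i, p in enumerate(placements)
        if isinstance(p, Replicate)
    )

# Replicated over "dp", sharded over "ep": only "dp" needs reduction.
print(replicate_dim_names([Replicate(), Shard(0)], ("dp", "ep")))  # ('dp',)
# Sharded everywhere: empty tuple, so the reduction branch is skipped,
# which is exactly the edge case the review asks to confirm is safe.
print(replicate_dim_names([Shard(0), Shard(1)], ("dp", "ep")))     # ()
```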

from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh

import tempfile

Claude: Two issues with test_save_hf_with_mtp:

  1. Warning — Test does not verify fp32 dtype preservation. The test checks torch.equal(origin_tensor, saved_tensor) for value equality, but never asserts that parameters matching fp32_keys_pattern (e.g., linear_attn.norm.weight, linear_attn.A_log) are saved in torch.float32 dtype. This is the core feature of the PR — consider adding dtype assertions.

  2. Nit — cache_save_fh dict uses safetensor filenames as keys. Both origin and saved dirs may produce the same filenames, causing key collisions. Use full file paths as dict keys instead.
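The filename-collision nit is easy to reproduce with plain pathlib (the paths below are invented for the example): two directories holding identically named safetensor shards collapse to one dict entry when only the filename is used as the key.

```python
from pathlib import Path

origin = Path("origin") / "model-00001-of-00002.safetensors"
saved = Path("saved") / "model-00001-of-00002.safetensors"

by_name = {origin.name: "origin", saved.name: "saved"}  # filename keys collide
by_path = {str(origin): "origin", str(saved): "saved"}  # full paths stay distinct

print(len(by_name))  # 1: the origin entry was silently overwritten
print(len(by_path))  # 2
```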

@claude

claude bot commented Mar 22, 2026

Claude: Summary

This PR adds fp32 parameter preservation during FSDP save/load via fp32_keys_pattern in HFSaveCfg, a _fully_shard() wrapper that distributes matched params as Replicate DTensors (excluded from FSDP via ignored_params), and per-param dtype selection at save time. Also fixes: world_mesh property setter, gradient reduction mesh selection, requires_grad preservation in EP distribution, and hardcoded .bfloat16() in save paths.

Issues

Warning

  • base.py:364-371 — _split_ignored_params yields ([], []) when no fp32 patterns match. _save_hf guards against empty yields for fused_gen but NOT for chain(same_gen, shard_gen), which could write empty safetensor files and cause file numbering gaps.
  • moe.py:877-896 — Non-DTensor params (or DTensor params without Replicate placements) now skip gradient reduction entirely when ep_enabled=True. The old code reduced all non-expert params unconditionally. Confirm no such params exist in practice.
  • test_qwen3_5.py — test_save_hf_with_mtp does not verify that fp32 params (linear_attn.norm.weight, linear_attn.A_log) are actually saved in torch.float32. Add dtype assertions to validate the core feature. Also, cache_save_fh uses safetensor filenames as dict keys — potential collision between origin and saved dirs.

Nit

  • (Already noted in prior review) Regex patterns in qwen3_5_text.py use unescaped dots. The docstring on fp32_keys_pattern correctly warns about this, but the actual patterns are already properly escaped with \. — the prior nit may have been a false positive.

Verdict

COMMENT — The _fully_shard wrapper with ignored_params is a clean design. The main concern is the unconditional yield of potentially empty lists from _split_ignored_params in all three _get_*_hf_param methods. Adding a guard (if name_list: yield ...) would prevent empty safetensor writes. The test should verify dtype preservation, not just value equality.

- Add `fp32_keys_pattern` to `HFSaveCfg` to specify params that should
  be saved in fp32 regardless of the global save dtype
- Add `_fully_shard()` to `BaseModel` which wraps `fully_shard()` and
  distributes matched params as Replicate DTensors so FSDP ignores them
- Replace all direct `fully_shard()` call sites with `self._fully_shard()`
- Add `_get_save_dtype()` to select fp32 per-param at save time
- Fix load path in `_load_same_hf_param` to skip shard-offset logic for
  Replicate DTensors (only apply for Shard-placed params)
- Fix `world_mesh` property setter bug: use `_world_mesh` directly
  instead of assigning through the read-only property
- Fix gradient reduce in `MoE.scale_and_reduce_grad` to correctly
  identify the Replicate mesh dimension via DTensor placements
- Preserve `requires_grad` when distributing params in EP and fp32 paths
- Configure `Qwen3_5_VLTextMoEConfig` with fp32 patterns for
  `linear_attn.norm.weight` and `linear_attn.A_log`
- Add `test_save_hf_with_mtp` to verify round-trip weight preservation
@HAOCHENYE merged commit b0fdc8d into InternLM:main on Mar 23, 2026
7 checks passed