[Fix] Support fp32 param preservation during FSDP save and load #1607
HAOCHENYE merged 1 commit into InternLM:main
Conversation
@claude review
xtuner/v1/model/moe/moe.py
Outdated
```python
if isinstance(param, DTensor):
    replicate_placements = [p for p in param.placements if isinstance(p, Replicate)]
    if replicate_placements:
```
Claude: Warning — Only the first Replicate dimension is reduced.
replicate_placements may contain more than one entry (e.g. for fp32-ignored params distributed as [Replicate(), Replicate()] on the 2D world_mesh). This loop only reduces across the first Replicate dimension, silently skipping subsequent ones.
Today this isn't triggered because matching fp32 params in MoE are already DTensors (from _replicate_other_params) with a single [Replicate()] on ep_mesh, so _fully_shard just adds them to ignored_params without re-distributing. But if a future change distributes a param as [Replicate(), Replicate()] on the 2D model mesh and it's trainable, its gradient will only be reduced across one mesh dimension.
Consider either:
- Reducing across all Replicate dimensions in a loop, or
- Adding an assertion / comment that only a single Replicate placement is expected here.
Suggested change:

```python
if isinstance(param, DTensor):
    replicate_placements = [p for p in param.placements if isinstance(p, Replicate)]
    # NOTE: currently all trainable Replicate params have exactly one
    # Replicate placement (from _replicate_other_params on ep_mesh).
    # If multi-Replicate params appear, this needs to reduce across
    # every Replicate dimension.
    if replicate_placements:
```
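The first option (reducing across every Replicate dimension) can be sketched without a distributed setup. The `Replicate`/`Shard` dataclasses and the `replicate_mesh_dims` helper below are hypothetical stand-ins for the `torch.distributed.tensor` placements, not the real API; in actual code the loop would issue one `dist.all_reduce` per returned mesh dimension.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for torch.distributed.tensor placement types.
@dataclass
class Replicate:
    pass

@dataclass
class Shard:
    dim: int

def replicate_mesh_dims(placements):
    """Return the index of *every* mesh dim carrying a Replicate placement,
    not just the first one (the bug flagged in the comment above)."""
    return [i for i, p in enumerate(placements) if isinstance(p, Replicate)]

# A param replicated on both dims of a 2D mesh: reducing only over the
# first Replicate dim would silently skip dim 1.
print(replicate_mesh_dims([Replicate(), Replicate()]))  # -> [0, 1]
print(replicate_mesh_dims([Shard(0), Replicate()]))     # -> [1]
```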
```python
bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
# TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32,
# currently it could only be specified in HFSaveCfg
fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None
```
Claude: Nit — Same regex dot-escaping concern applies here. Patterns from fp32_keys_pattern are used as regex via re.search. If pattern authors forget to escape dots, matches will be overly broad. Consider documenting that values are regex (not glob/fnmatch) in the field docstring, or normalizing dots on behalf of the user.
xtuner/v1/model/base.py
Outdated
```diff
  continue
  local_tensor = param._local_tensor if isinstance(param, DTensor) else param
- local_tensor = local_tensor.bfloat16()
+ local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], dtype))
```
Claude: Warning — Pre-existing hardcoded .bfloat16() replaced, but only for one of two callers.
Good fix — the original .bfloat16() ignored the save_dtype parameter. However, the same pattern also exists in _get_fused_hf_param. Make sure that path is covered too (it appears to be at line ~945 in the diff — just confirming this is intentional and not missed).
```python
    dist_param = nn.Parameter(
        distribute_tensor(
            param,
            self.world_mesh,
            [Replicate() for _ in range(self.world_mesh.ndim)]
        ),
        requires_grad=param.requires_grad
    )
    module.register_parameter(name, dist_param)
    ignored_params.add(dist_param)
else:
    ignored_params.add(param)
break
```
Claude: The _fully_shard wrapper is well-designed — collecting fp32-pattern params into ignored_params and distributing them as Replicate DTensors is the right approach to keep FSDP from sharding them.
One subtle note: when a param is already a DTensor (e.g., after _replicate_other_params in MoE), the else branch just adds it to ignored_params without re-distributing on self.world_mesh. This is correct today because MoE's _replicate_other_params already distributed it as [Replicate()] on ep_mesh. But if the caller order ever changes (fp32 distribution before EP distribution), this would silently skip the distribution. A comment noting this assumption would help.
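For context, the per-key dtype selection that `_get_save_dtype` presumably performs can be sketched as follows. This is a guess at the logic, not the actual xtuner implementation, and it uses dtype names as plain strings to stay framework-free.

```python
import re

def get_save_dtype(hf_key, default_dtype, fp32_keys_pattern=None):
    # Hypothetical sketch: keys matching any fp32 pattern are kept in
    # float32; everything else uses the global save dtype.
    if fp32_keys_pattern and any(re.search(p, hf_key) for p in fp32_keys_pattern):
        return "float32"
    return default_dtype

patterns = [r"linear_attn\.norm\.weight", r"linear_attn\.A_log"]
print(get_save_dtype("model.layers.0.linear_attn.A_log", "bfloat16", patterns))  # float32
print(get_save_dtype("model.layers.0.mlp.gate.weight", "bfloat16", patterns))    # bfloat16
```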
```python
import tempfile
from pathlib import Path
import json
from safetensors import safe_open
```
Claude: Nit — Import placement. Per convention in this file, stdlib and third-party imports should be grouped with the existing imports at the top (lines 1–15), not appended after.
Claude:

## Summary
This PR adds fp32 parameter preservation during FSDP save and load.

## Issues
Warning and Nit findings are noted inline above.

## Verdict
COMMENT — The core approach is sound.
Force-pushed from 7980909 to ac307bf
@claude review
```python
for name, param in module.named_parameters(recurse=False):
    full_name = full_param_name_mapping[id(param)]
    full_name = self._clean_param_name(full_name)
    hf_name_list = self.to_hf_key_list(full_name)

    for hf_name in hf_name_list:
        if any(re.search(p, hf_name) for p in patterns):  # type: ignore
            if not isinstance(param, DTensor):
```
Claude: Warning — _split_ignored_params yield can produce empty lists.
In _get_shard_hf_param, _get_fused_hf_param, and _get_same_hf_param, after calling self._split_ignored_params(params), the code unconditionally yields name_list, hf_params from the ignored params. If there are no ignored params (i.e., fp32_keys_pattern is None or no pattern matches), this yields ([], []).
Downstream in _save_hf, the fused_gen path has a guard (if not name_list: continue) but the chain(same_gen, shard_gen) path does not — this could cause empty safetensor files to be written and gaps in file numbering.
Consider guarding the yield:
```python
ignored_params, params = self._split_ignored_params(params)
name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
if name_list:
    yield name_list, hf_params
```
The same pattern should be applied in _get_fused_hf_param and _get_same_hf_param.
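The empty-yield hazard is easy to reproduce in isolation. The generator names below are illustrative, not the actual xtuner helpers; the point is that an unconditional yield of empty lists would translate into an empty safetensor file downstream.

```python
def unguarded_gen(ignored):
    names = [n for n, _ in ignored]
    params = [p for _, p in ignored]
    yield names, params          # yields ([], []) even when nothing matched

def guarded_gen(ignored):
    names = [n for n, _ in ignored]
    params = [p for _, p in ignored]
    if names:                    # skip the yield entirely when empty
        yield names, params

print(list(unguarded_gen([])))        # [([], [])] -> would write an empty file
print(list(guarded_gen([])))          # []
print(list(guarded_gen([("w", 1.0)])))  # [(['w'], [1.0])]
```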
```python
        param.device_mesh.mesh_dim_names[i]
        for i, p in enumerate(param.placements)
        if isinstance(p, Replicate)
    )
    if replicate_dim_names:
        # `DeviceMesh.get_group()` only supports a single mesh dimension,
        # so calling it directly on a multi-dim sub-mesh raises RuntimeError.
        # `_flatten()` collapses all Replicate dims into a 1D mesh whose
        # process group covers every rank across those dimensions, allowing
        # a single all_reduce regardless of how many Replicate dims exist.
        flat_mesh = param.device_mesh[replicate_dim_names]._flatten()
        grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
        dist.all_reduce(
            grad.div_(flat_mesh.size()),  # type: ignore
            ReduceOp.SUM,
            group=flat_mesh.get_group(),  # type: ignore
        )

def _init_device_mesh(self, fsdp_config: FSDPConfig):
    self.fsdp_config = fsdp_config
```
Claude: Warning — Non-DTensor params with ep_enabled=True now silently skip gradient reduction.
The old code reduced all non-expert params when `ep_enabled=True`:

```python
if ep_enabled:
    grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
    dist.all_reduce(grad.div_(...), ...)
```

The new code only reduces DTensor params that have Replicate placements. If `ep_enabled=True` but a param is not a DTensor (shouldn't happen in normal FSDP, but defensive coding matters), its gradient is silently dropped from reduction.
Also, DTensor params with only Shard placements and no Replicate get replicate_dim_names = (), so their gradients are also not reduced. Confirm that no such params exist in practice.
The _flatten() approach for multi-dim Replicate meshes is a good improvement over the original hardcoded mesh_dim=0.
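One way to see why `grad.div_(flat_mesh.size())` stays correct for multi-dim Replicate: after `_flatten()`, the process-group size is the product of the Replicate dims' sizes, so a single SUM all_reduce of the pre-divided gradient yields the mean over every replica. A framework-free sketch of that arithmetic (the helper name is invented for illustration):

```python
from math import prod

def flattened_group_size(mesh_shape, replicate_dims):
    # Group size after flattening the Replicate dims of the mesh:
    # the product of those dims' sizes.
    return prod(mesh_shape[d] for d in replicate_dims)

# 2D mesh of shape (4, 2), replicated on both dims -> 8 replicas.
n = flattened_group_size((4, 2), [0, 1])
print(n)  # 8

# div-then-SUM averaging: each replica holds a gradient g; summing g/n
# across all replicas reproduces the mean.
grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(sum(g / n for g in grads))  # 4.5, the mean of grads
```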
```python
from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh


import tempfile
```
Claude: Two issues with `test_save_hf_with_mtp`:

1. Warning — Test does not verify fp32 dtype preservation. The test checks `torch.equal(origin_tensor, saved_tensor)` for value equality, but never asserts that parameters matching `fp32_keys_pattern` (e.g., `linear_attn.norm.weight`, `linear_attn.A_log`) are saved in `torch.float32` dtype. This is the core feature of the PR — consider adding dtype assertions.
2. Nit — The `cache_save_fh` dict uses safetensor filenames as keys. Both origin and saved dirs may produce the same filenames, causing key collisions. Use full file paths as dict keys instead.
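The missing dtype assertion could look roughly like the loop below. It is shown against a plain dict of key-to-dtype-name so it runs standalone; in the real test the dtypes would come from the tensors read back out of the saved safetensor shards, and the keys here are only examples.

```python
import re

fp32_patterns = [r"linear_attn\.norm\.weight", r"linear_attn\.A_log"]

# Stand-in for the dtypes read back from the saved checkpoint.
saved_dtypes = {
    "model.layers.0.linear_attn.A_log": "float32",
    "model.layers.0.linear_attn.norm.weight": "float32",
    "model.layers.0.mlp.gate.weight": "bfloat16",
}

for key, dtype in saved_dtypes.items():
    if any(re.search(p, key) for p in fp32_patterns):
        # Params matching fp32_keys_pattern must be preserved in fp32.
        assert dtype == "float32", f"{key} should be fp32, got {dtype}"
    else:
        assert dtype == "bfloat16", f"{key} unexpectedly saved as {dtype}"
print("dtype checks passed")
```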
Claude:

## Summary
This PR adds fp32 parameter preservation during FSDP save and load.

## Issues
Warning and Nit findings are noted inline above.

## Verdict
COMMENT
- Add `fp32_keys_pattern` to `HFSaveCfg` to specify params that should be saved in fp32 regardless of the global save dtype
- Add `_fully_shard()` to `BaseModel` which wraps `fully_shard()` and distributes matched params as Replicate DTensors so FSDP ignores them
- Replace all direct `fully_shard()` call sites with `self._fully_shard()`
- Add `_get_save_dtype()` to select fp32 per-param at save time
- Fix load path in `_load_same_hf_param` to skip shard-offset logic for Replicate DTensors (only apply for Shard-placed params)
- Fix `world_mesh` property setter bug: use `_world_mesh` directly instead of assigning through the read-only property
- Fix gradient reduce in `MoE.scale_and_reduce_grad` to correctly identify the Replicate mesh dimension via DTensor placements
- Preserve `requires_grad` when distributing params in EP and fp32 paths
- Configure `Qwen3_5_VLTextMoEConfig` with fp32 patterns for `linear_attn.norm.weight` and `linear_attn.A_log`
- Add `test_save_hf_with_mtp` to verify round-trip weight preservation
Force-pushed from ac307bf to 53c385c