Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/model/test_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh

import tempfile
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Two issues with test_save_hf_with_mtp:

  1. Warning — Test does not verify fp32 dtype preservation. The test checks torch.equal(origin_tensor, saved_tensor) for value equality, but never asserts that parameters matching fp32_keys_pattern (e.g., linear_attn.norm.weight, linear_attn.A_log) are saved in torch.float32 dtype. This is the core feature of the PR — consider adding dtype assertions.

  2. Resolved nit — cache_save_fh keys. The implementation already keys the dict by full file paths (str(origin_hf_path / name) and str(tmpdir / name)), so identical safetensor filenames in the origin and saved directories cannot collide.

from pathlib import Path
import json
from safetensors import safe_open
Comment on lines +18 to +21
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit — Import placement. Per convention in this file, stdlib and third-party imports should be grouped with the existing imports at the top (lines 1–15), not appended after.



VIDEO_ROOT = os.environ["VIDEO_ROOT"]

Expand Down Expand Up @@ -216,6 +221,82 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))

@parametrize.parametrize(
    "device,sp_size",
    [
        ("cuda", 1),
    ],
)
def test_save_hf_with_mtp(self, device, sp_size):
    """Round-trip a Qwen3.5 MoE checkpoint through ``from_hf``/``save_hf``.

    Verifies that (a) every original HF weight is preserved byte-for-byte in
    the saved output, and (b) the saved ``model.safetensors.index.json``
    agrees exactly with the tensors actually written to disk.

    NOTE(review): this does not yet assert that parameters matching
    ``fp32_keys_pattern`` are saved in float32 — add dtype assertions once
    the test config enables that feature. MTP weights are also skipped
    (see TODOs) until MTP is implemented.
    """
    self.create_pg(device)
    QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]

    with torch.device("meta"):
        model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
        qwen3vl_model = model_cfg.build().to(torch.bfloat16)

    fsdp_config = FSDPConfig(cpu_offload=False)
    fsdp_mesh = init_world_mesh()
    qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
    qwen3vl_model.vision_tower.fsdp_config = fsdp_config
    qwen3vl_model.fully_shard(fsdp_config=fsdp_config)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Rank 0 owns the temp directory name; broadcast it so every rank
        # saves into the same location.
        syncdir = [tmpdir]
        dist.broadcast_object_list(syncdir, src=0)
        tmpdir = Path(syncdir[0])
        qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
        qwen3vl_model.save_hf(tmpdir)

        origin_hf_path = Path(QWEN3_VL_MOE_PATH)
        origin_index_path = origin_hf_path / "model.safetensors.index.json"
        saved_index_path = tmpdir / "model.safetensors.index.json"

        if dist.get_rank() == 0:
            with open(origin_index_path, "r") as f:
                origin_index = json.load(f)
            with open(saved_index_path, "r") as f:
                saved_index = json.load(f)

            # Keyed by full file path, so identical safetensor filenames in
            # the origin and saved directories cannot collide.
            cache_save_fh: dict = {}

            # Verify all original HF weights are preserved correctly
            for key in origin_index["weight_map"].keys():
                if "mtp" in key:
                    continue  # TODO: remove this after MTP is implemented
                origin_safetensor_name = origin_index["weight_map"][key]
                saved_safetensor_name = saved_index["weight_map"][key]

                origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name)
                saved_sf_fh_name = str(tmpdir / saved_safetensor_name)

                if origin_sf_fh_name not in cache_save_fh:
                    cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt")
                if saved_sf_fh_name not in cache_save_fh:
                    cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt")

                origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key)
                saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key)

                self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}")

            # Verify MTP weights are present in the saved output
            mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")]
            # TODO: remove skip after MTP is implemented
            _ = mtp_keys

            # Verify the tensor count in safetensors matches the saved index
            safetensor_keys: list[str] = []
            for safetensor_path in tmpdir.glob("*.safetensors"):
                # Context manager so each file handle is released promptly
                # (the original leaked these handles).
                with safe_open(str(safetensor_path), framework="pt") as fh:
                    safetensor_keys.extend(fh.keys())
            safetensor_keys.sort()
            model_index_keys = list(saved_index["weight_map"].keys())
            model_index_keys.sort()
            self.assertListEqual(safetensor_keys, model_index_keys)

        dist.barrier()

@property
def world_size(self) -> int:
    """World size for the distributed test run.

    Defaults to 4 when the ``XTUNER_TEST_WORLD_SIZE`` environment variable
    is unset.
    """
    configured = os.getenv("XTUNER_TEST_WORLD_SIZE")
    return 4 if configured is None else int(configured)
146 changes: 137 additions & 9 deletions xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MixedPrecisionPolicy,
fully_shard,
)
from torch.distributed.tensor import DTensor, Placement, Shard
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from typing_extensions import NotRequired, Self, TypedDict, overload

Expand Down Expand Up @@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel):
worker_per_rank: Annotated[int, Parameter(group="model")] = 16
max_save_rank: Annotated[int, Parameter(group="model")] = 16
bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
# TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32,
# currently it could only be specified in HFSaveCfg
# Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name.
# Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of
# r"model.layers.\d+.weight" to avoid unintended wildcard matches.
fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit — Same regex dot-escaping concern applies here. Patterns from fp32_keys_pattern are used as regex via re.search. If pattern authors forget to escape dots, matches will be overly broad. Consider documenting that values are regex (not glob/fnmatch) in the field docstring, or normalizing dots on behalf of the user.



class XTunerBaseModelConfig(PydanticBaseModel):
Expand Down Expand Up @@ -313,6 +319,7 @@ def fully_shard(
"""Fully shard the model parameters."""
self.fsdp_config = fsdp_config
self.fsdp_mesh = self._init_world_mesh()
self._world_mesh = self.fsdp_mesh

if self.fsdp_config.requires_grad:
for name, module in self.named_modules():
Expand All @@ -337,15 +344,79 @@ def fully_shard(
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
)
return self

def _fully_shard(
    self,
    mesh: DeviceMesh,
    mp_policy: MixedPrecisionPolicy,
    reshard_after_forward: bool,
    offload_policy: CPUOffloadPolicy | None,
    module: nn.Module | None = None,
) -> None:
    """Wrap ``fully_shard`` so parameters matching ``fp32_keys_pattern`` are
    excluded from FSDP sharding.

    Matching parameters are replaced with ``Replicate`` DTensors on the world
    mesh and handed to FSDP via ``ignored_params``, so FSDP leaves them alone
    (presumably so they can stay in fp32 — see ``HFSaveCfg.fp32_keys_pattern``).

    Args:
        mesh: device mesh to shard over, forwarded to ``fully_shard``.
        mp_policy: mixed-precision policy, forwarded to ``fully_shard``.
        reshard_after_forward: forwarded to ``fully_shard``.
        offload_policy: CPU offload policy or ``None``, forwarded as-is.
        module: submodule to shard; defaults to ``self`` when ``None``.
    """

    def traverse(module):
        # recurse=False: each parameter is visited exactly once, at the
        # module that owns it, so register_parameter targets the right owner.
        for name, param in module.named_parameters(recurse=False):
            full_name = full_param_name_mapping[id(param)]
            full_name = self._clean_param_name(full_name)
            hf_name_list = self.to_hf_key_list(full_name)

            for hf_name in hf_name_list:
                if any(re.search(p, hf_name) for p in patterns):  # type: ignore
                    if not isinstance(param, DTensor):
                        # Replicate on every mesh dim so the param is whole
                        # on all ranks and FSDP can safely ignore it.
                        dist_param = nn.Parameter(
                            distribute_tensor(
                                param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)]
                            ),
                            requires_grad=param.requires_grad,
                        )
                        module.register_parameter(name, dist_param)
                        ignored_params.add(dist_param)
                    else:
                        # param is already a DTensor (e.g. distributed by
                        # MoE._replicate_other_params on ep_mesh before _fully_shard
                        # is called). We skip re-distributing on world_mesh and just
                        # add it to ignored_params so FSDP leaves it alone.
                        # ASSUMPTION: fp32 distribution always happens AFTER any
                        # prior EP distribution, so the existing placement is correct.
                        ignored_params.add(param)
                    # One HF-key match is enough; stop scanning this param.
                    break

        for child in module.children():
            traverse(child)

    # Collect the parameters of `target` that match any fp32 pattern so they can be
    # excluded from FSDP sharding (passed as `ignored_params`).
    #
    # We intentionally iterate over `self.named_parameters()` rather than
    # `target.named_parameters()` so that `name` is always relative to the root model
    # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`):
    # `target.named_parameters()` would yield bare names like `"weight"`, which
    # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the
    # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters
    # using identity comparison.
    full_param_name_mapping = {id(param): name for name, param in self.named_parameters()}
    ignored_params: set[nn.Parameter] = set()
    patterns = self.config.hf_save_cfg.fp32_keys_pattern

    target = module or self
    if patterns:
        traverse(target)

    fully_shard(
        target,
        mesh=mesh,
        mp_policy=mp_policy,
        reshard_after_forward=reshard_after_forward,
        offload_policy=offload_policy,
        ignored_params=ignored_params if ignored_params else None,
    )

def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
    """Save the model in HuggingFace safetensors format under ``hf_dir``.

    Thin timing wrapper around ``_save_hf``; ``save_dtype`` is the default
    dtype tensors are cast to on save.
    """
    timer_label = f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"
    with profile_time_and_memory(timer_label):
        self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
Expand Down Expand Up @@ -396,6 +467,12 @@ def device(self) -> torch.device:
return torch.device("cpu")
return torch.device(DEVICE)

@property
def world_mesh(self) -> DeviceMesh | None:
    """Device mesh for this model, built lazily on first access.

    Cached in ``self._world_mesh``; whatever ``_init_world_mesh`` returns
    (including ``None``) is cached and reused on later accesses.
    """
    try:
        return self._world_mesh
    except AttributeError:
        self._world_mesh = self._init_world_mesh()
        return self._world_mesh

@property
def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
return {}
Expand Down Expand Up @@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
)
return ret

def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype:
patterns = self.config.hf_save_cfg.fp32_keys_pattern
if patterns and any(re.search(p, name) for p in patterns):
return torch.float32
return dtype

def _get_shard_hf_param(
self,
params: list[tuple[torch.Tensor, LoadSpec]],
Expand All @@ -679,6 +762,16 @@ def _get_shard_hf_param(
) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
yield name_list, hf_params

if not params:
return

if dtype != torch.bfloat16:
raise NotImplementedError

Expand All @@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
# Get unsharded params
_unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group)
unsharded_tensor_list = [
torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list
torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list
]
name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list]
unsharded_tensor_list = [
Expand All @@ -711,11 +804,11 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis

safetensor_size = 0
tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
name_list: list[str] = []
name_list = []

for param, load_spec in params:
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.to(dtype=dtype)
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
hf_params = _get_hf_params(tensor_list)
Expand Down Expand Up @@ -744,6 +837,12 @@ def _get_fused_hf_param(
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
yield fp32_name_list, fp32_params

def _get_hf_params(
fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
name_list: list[str],
Expand Down Expand Up @@ -867,7 +966,7 @@ def _get_hf_params(

for param, load_spec in params:
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
hf_params, name_list = _get_hf_params(tensor_list, name_list)
Expand All @@ -893,6 +992,15 @@ def _get_same_hf_param(
) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
fp32_tensor_list: list[torch.Tensor] = [
param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params
]
yield fp32_name_list, fp32_tensor_list

if bucket_size is None:
bucket_size = self.config.hf_save_cfg.bucket_size
safetensor_size = 0
Expand All @@ -909,7 +1017,7 @@ def _get_same_hf_param(
buffer_name_list.append(load_spec.hf_keys[0])
continue
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
if self.fsdp_mesh is not None:
Expand Down Expand Up @@ -953,6 +1061,21 @@ def _get_same_hf_param(
if buffer_tensor_list:
yield buffer_name_list, buffer_tensor_list

def _is_ignored_params(self, key: str):
patterns = self.config.hf_save_cfg.fp32_keys_pattern
if patterns is None:
return False
return any(re.search(p, key) for p in patterns)

def _split_ignored_params(
self, params: list[tuple[torch.Tensor, LoadSpec]]
) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]:
if not self.config.hf_save_cfg.fp32_keys_pattern:
return [], params
ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])]
remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])]
return ignored_params, remaining

# TODO: Using `xtuenr.v1.utils.misc.clean_param_name`
def _clean_param_name(self, name: str) -> str:
if "_checkpoint_wrapped_module." in name:
Expand Down Expand Up @@ -1230,7 +1353,12 @@ def _load_same_hf_param(

loaded_tensor = loaded_tensor.to(local_tensor.device)

if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
if (
self.fsdp_mesh is not None
and isinstance(param, nn.Parameter)
and isinstance(param, DTensor)
and any(isinstance(p, Shard) for p in param.placements)
):
shape_before_fsdp = load_spec.shape
_, _offset = compute_local_shape_and_global_offset(
shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
Expand Down
4 changes: 1 addition & 3 deletions xtuner/v1/model/compose/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
from typing_extensions import override

Expand Down Expand Up @@ -108,8 +107,7 @@ def fully_shard(
# Note: 非常关键,不能删除这个 assert
assert self.fsdp_mesh is not None

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
Expand Down
3 changes: 1 addition & 2 deletions xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def fully_shard(
# Note: 非常关键,不能删除这个 assert
assert self.fsdp_mesh is not None

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
Expand Down
3 changes: 1 addition & 2 deletions xtuner/v1/model/compose/intern_s1/modeling_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def fully_shard(
for param in self.parameters():
param.requires_grad = False

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=True,
Expand Down
Loading
Loading