diff --git a/tests/model/test_qwen3_5.py b/tests/model/test_qwen3_5.py
index afe348aa9..0431d3640 100644
--- a/tests/model/test_qwen3_5.py
+++ b/tests/model/test_qwen3_5.py
@@ -15,6 +15,11 @@
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh
+import tempfile
+from pathlib import Path
+import json
+from safetensors import safe_open
+
 
 VIDEO_ROOT = os.environ["VIDEO_ROOT"]
@@ -216,6 +221,82 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
         self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
         self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))
 
+    @parametrize.parametrize(
+        "device,sp_size",
+        [
+            ("cuda", 1),
+        ],
+    )
+    def test_save_hf_with_mtp(self, device, sp_size):
+        self.create_pg(device)
+        QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]
+
+        with torch.device("meta"):
+            model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
+            qwen3vl_model = model_cfg.build().to(torch.bfloat16)
+
+        fsdp_config = FSDPConfig(cpu_offload=False)
+        fsdp_mesh = init_world_mesh()
+        qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
+        qwen3vl_model.vision_tower.fsdp_config = fsdp_config
+        qwen3vl_model.fully_shard(fsdp_config=fsdp_config)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            syncdir = [tmpdir]
+            dist.broadcast_object_list(syncdir, src=0)
+            tmpdir = Path(syncdir[0])
+            qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
+            qwen3vl_model.save_hf(tmpdir)
+
+            origin_hf_path = Path(QWEN3_VL_MOE_PATH)
+            origin_index_path = origin_hf_path / "model.safetensors.index.json"
+            saved_index_path = tmpdir / "model.safetensors.index.json"
+
+            if dist.get_rank() == 0:
+                with open(origin_index_path, "r") as f:
+                    origin_index = json.load(f)
+                with open(saved_index_path, "r") as f:
+                    saved_index = json.load(f)
+
+                cache_save_fh: dict = {}
+
+                # Verify all original HF weights are preserved correctly
+                for key in origin_index["weight_map"].keys():
+                    if "mtp" in key:
+                        continue  # TODO: remove this after MTP is implemented
+                    origin_safetensor_name = origin_index["weight_map"][key]
+                    saved_safetensor_name = saved_index["weight_map"][key]
+
+                    origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name)
+                    saved_sf_fh_name = str(tmpdir / saved_safetensor_name)
+
+                    if origin_sf_fh_name not in cache_save_fh:
+                        cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt")
+                    if saved_sf_fh_name not in cache_save_fh:
+                        cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt")
+
+                    origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key)
+                    saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key)
+
+                    self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}")
+
+                # Verify MTP weights are present in the saved output
+                mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")]
+                # TODO: remove skip after MTP is implemented
+                _ = mtp_keys
+
+                # Verify the tensor count in safetensors matches the saved index
+                safetensor_keys: list[str] = []
+                for safetensor_path in tmpdir.glob("*.safetensors"):
+                    fh = safe_open(str(safetensor_path), framework="pt")
+                    safetensor_keys.extend(fh.keys())
+                safetensor_keys.sort()
+                model_index_keys = list(saved_index["weight_map"].keys())
+                model_index_keys.sort()
+                self.assertListEqual(safetensor_keys, model_index_keys)
+
+            dist.barrier()
+
     @property
     def world_size(self) -> int:
         return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "4"))
diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index 22f2c7139..071cd4ebc 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -26,7 +26,7 @@
     MixedPrecisionPolicy,
     fully_shard,
 )
-from torch.distributed.tensor import DTensor, Placement, Shard
+from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from typing_extensions import NotRequired, Self, TypedDict, overload
@@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel):
     worker_per_rank: Annotated[int, Parameter(group="model")] = 16
     max_save_rank: Annotated[int, Parameter(group="model")] = 16
     bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
+    # TODO: `XTunerBaseModel` should also be able to specify which parameters are trained in fp32;
+    # currently it can only be specified in HFSaveCfg.
+    # Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name.
+    # Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of
+    # r"model.layers.\d+.weight" to avoid unintended wildcard matches.
+    fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None
 
 
 class XTunerBaseModelConfig(PydanticBaseModel):
@@ -313,6 +319,7 @@ def fully_shard(
         """Fully shard the model parameters."""
         self.fsdp_config = fsdp_config
         self.fsdp_mesh = self._init_world_mesh()
+        self._world_mesh = self.fsdp_mesh
 
         if self.fsdp_config.requires_grad:
             for name, module in self.named_modules():
@@ -337,8 +344,7 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,
         )
         return self
@@ -346,6 +352,71 @@
+    def _fully_shard(
+        self,
+        mesh: DeviceMesh,
+        mp_policy: MixedPrecisionPolicy,
+        reshard_after_forward: bool,
+        offload_policy: CPUOffloadPolicy | None,
+        module: nn.Module | None = None,
+    ) -> None:
+        def traverse(module):
+            for name, param in module.named_parameters(recurse=False):
+                full_name = full_param_name_mapping[id(param)]
+                full_name = self._clean_param_name(full_name)
+                hf_name_list = self.to_hf_key_list(full_name)
+
+                for hf_name in hf_name_list:
+                    if any(re.search(p, hf_name) for p in patterns):  # type: ignore
+                        if not isinstance(param, DTensor):
+                            dist_param = nn.Parameter(
+                                distribute_tensor(
+                                    param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)]
+                                ),
+                                requires_grad=param.requires_grad,
+                            )
+                            module.register_parameter(name, dist_param)
+                            ignored_params.add(dist_param)
+                        else:
+                            # param is already a DTensor (e.g. distributed by
+                            # MoE._replicate_other_params on ep_mesh before _fully_shard
+                            # is called). We skip re-distributing on world_mesh and just
+                            # add it to ignored_params so FSDP leaves it alone.
+                            # ASSUMPTION: fp32 distribution always happens AFTER any
+                            # prior EP distribution, so the existing placement is correct.
+                            ignored_params.add(param)
+                        break
+
+            for child in module.children():
+                traverse(child)
+
+        # Collect the parameters of `target` that match any fp32 pattern so they can be
+        # excluded from FSDP sharding (passed as `ignored_params`).
+        #
+        # We intentionally iterate over `self.named_parameters()` rather than
+        # `target.named_parameters()` so that `name` is always relative to the root model
+        # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`):
+        # `target.named_parameters()` would yield bare names like `"weight"`, which
+        # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the
+        # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters
+        # using identity comparison.
+        full_param_name_mapping = {id(param): name for name, param in self.named_parameters()}
+        ignored_params: set[nn.Parameter] = set()
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+
+        target = module or self
+        if patterns:
+            traverse(target)
+
+        fully_shard(
+            target,
+            mesh=mesh,
+            mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
+            ignored_params=ignored_params if ignored_params else None,
+        )
+
     def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
         with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
             self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
@@ -396,6 +467,12 @@ def device(self) -> torch.device:
             return torch.device("cpu")
         return torch.device(DEVICE)
 
+    @property
+    def world_mesh(self) -> DeviceMesh | None:
+        if not hasattr(self, "_world_mesh"):
+            self._world_mesh = self._init_world_mesh()
+        return self._world_mesh
+
     @property
     def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
         return {}
@@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
         )
         return ret
 
+    def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype:
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns and any(re.search(p, name) for p in patterns):
+            return torch.float32
+        return dtype
+
     def _get_shard_hf_param(
         self,
         params: list[tuple[torch.Tensor, LoadSpec]],
@@ -679,6 +762,16 @@
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield name_list, hf_params
+
+        if not params:
+            return
+
         if dtype != torch.bfloat16:
             raise NotImplementedError
@@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
             # Get unsharded params
             _unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group)
             unsharded_tensor_list = [
-                torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list
+                torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list
             ]
             name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list]
             unsharded_tensor_list = [
@@ -711,11 +804,11 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
         safetensor_size = 0
         tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-        name_list: list[str] = []
+        name_list = []
 
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.to(dtype=dtype)
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params = _get_hf_params(tensor_list)
@@ -744,6 +837,12 @@ def _get_fused_hf_param(
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield fp32_name_list, fp32_params
+
         def _get_hf_params(
             fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
             name_list: list[str],
@@ -867,7 +966,7 @@ def _get_hf_params(
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params, name_list = _get_hf_params(tensor_list, name_list)
@@ -893,6 +992,15 @@ def _get_same_hf_param(
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_tensor_list: list[torch.Tensor] = [
+                param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params
+            ]
+            yield fp32_name_list, fp32_tensor_list
+
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
         safetensor_size = 0
@@ -909,7 +1017,7 @@ def _get_same_hf_param(
                 buffer_name_list.append(load_spec.hf_keys[0])
                 continue
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 if self.fsdp_mesh is not None:
@@ -953,6 +1061,21 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
 
+    def _is_ignored_params(self, key: str):
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns is None:
+            return False
+        return any(re.search(p, key) for p in patterns)
+
+    def _split_ignored_params(
+        self, params: list[tuple[torch.Tensor, LoadSpec]]
+    ) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]:
+        if not self.config.hf_save_cfg.fp32_keys_pattern:
+            return [], params
+        ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])]
+        remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])]
+        return ignored_params, remaining
+
     # TODO: Using `xtuenr.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
         if "_checkpoint_wrapped_module." in name:
@@ -1230,7 +1353,12 @@ def _load_same_hf_param(
             loaded_tensor = loaded_tensor.to(local_tensor.device)
 
-            if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
+            if (
+                self.fsdp_mesh is not None
+                and isinstance(param, nn.Parameter)
+                and isinstance(param, DTensor)
+                and any(isinstance(p, Shard) for p in param.placements)
+            ):
                 shape_before_fsdp = load_spec.shape
                 _, _offset = compute_local_shape_and_global_offset(
                     shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
diff --git a/xtuner/v1/model/compose/base.py b/xtuner/v1/model/compose/base.py
index e878cb457..a2cc24e51 100644
--- a/xtuner/v1/model/compose/base.py
+++ b/xtuner/v1/model/compose/base.py
@@ -10,7 +10,6 @@
     CPUOffloadPolicy,
     FSDPModule,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from typing_extensions import override
@@ -108,8 +107,7 @@ def fully_shard(
         # Note: 非常关键,不能删除这个 assert
         assert self.fsdp_mesh is not None
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,
diff --git a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
index c82028a21..0f4563896 100644
--- a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
+++ b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
@@ -83,8 +83,7 @@ def fully_shard(
         # Note: 非常关键,不能删除这个 assert
         assert self.fsdp_mesh is not None
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,
diff --git a/xtuner/v1/model/compose/intern_s1/modeling_projector.py b/xtuner/v1/model/compose/intern_s1/modeling_projector.py
index 782a2dd3d..8fda62651 100644
--- a/xtuner/v1/model/compose/intern_s1/modeling_projector.py
+++ b/xtuner/v1/model/compose/intern_s1/modeling_projector.py
@@ -70,8 +70,7 @@ def fully_shard(
             for param in self.parameters():
                 param.requires_grad = False
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=True,
diff --git a/xtuner/v1/model/compose/intern_s1/modeling_vision.py b/xtuner/v1/model/compose/intern_s1/modeling_vision.py
index e263cb58c..9aef186df 100644
--- a/xtuner/v1/model/compose/intern_s1/modeling_vision.py
+++ b/xtuner/v1/model/compose/intern_s1/modeling_vision.py
@@ -393,21 +393,18 @@ def fully_shard(
             self.encoder.layer[layer_idx] = layer
 
-            fully_shard(
-                layer,
+            self._fully_shard(
                 mesh=self.fsdp_mesh,
                 mp_policy=mp_policy,
                 reshard_after_forward=True,
-                offload_policy=CPUOffloadPolicy()
-                if fsdp_config.cpu_offload
-                else None,
+                offload_policy=CPUOffloadPolicy() if fsdp_config.cpu_offload else None,
+                module=layer,
             )
 
         for layer_cur, layer_next in zip(self.encoder.layer[:-1], self.encoder.layer[1:]):
             layer_cur.set_modules_to_forward_prefetch([layer_next])
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=True,
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
index 2ffb53d9e..338d44686 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
@@ -102,8 +102,7 @@ def fully_shard(
             for param in self.parameters():
                 param.requires_grad = False
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=True,
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
index bdf464488..ee7e47a5a 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
@@ -350,21 +350,18 @@ def fully_shard(
             self.blocks[layer_idx] = layer
 
-            fully_shard(
-                layer,
+            self._fully_shard(
                 mesh=self.fsdp_mesh,
                 mp_policy=decoder_layer_mp_policy,
                 reshard_after_forward=True,
-                offload_policy=CPUOffloadPolicy()
-                if fsdp_config.cpu_offload
-                else None,
+                offload_policy=CPUOffloadPolicy() if fsdp_config.cpu_offload else None,
+                module=layer,
             )
 
         for layer_cur, layer_next in zip(self.blocks[:-1], self.blocks[1:]):
             layer_cur.set_modules_to_forward_prefetch([layer_next])
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=True,
diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py
index f44c8c8de..36c9465a8 100644
--- a/xtuner/v1/model/dense/dense.py
+++ b/xtuner/v1/model/dense/dense.py
@@ -11,7 +11,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from torch.distributed.tensor import DTensor
 from tqdm import tqdm
@@ -242,12 +241,12 @@ def fully_shard(
                 layer.forward = torch.compile(layer.forward, fullgraph=True)
             self.layers[str(layer_idx)] = layer
 
-            fully_shard(
-                layer,
+            self._fully_shard(
                 mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
                 mp_policy=mp_policy,
                 reshard_after_forward=self.fsdp_config.reshard_after_forward,
                 offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+                module=layer,
             )
 
         for layer_cur, layer_next in zip(
@@ -256,32 +255,31 @@ def fully_shard(
         ):
             layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
 
-        fully_shard(
-            self.embed_tokens,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=lm_head_mp_policy if self.config.tie_word_embeddings else mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.embed_tokens,
         )
 
-        fully_shard(
-            self.norm,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.norm,
         )
 
-        fully_shard(
-            self.lm_head,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.lm_head,
        )
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index c7aad8704..18435e1ad 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -18,7 +18,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from torch.distributed.tensor import DTensor, Replicate, distribute_tensor
 from tqdm import tqdm
@@ -794,12 +793,13 @@ def fully_shard(
                 reshard_after_forward = False
             else:
                 reshard_after_forward = self.fsdp_config.reshard_after_forward
-            fully_shard(
-                layer,
+
+            self._fully_shard(
                 mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
                 mp_policy=mp_policy,
                 reshard_after_forward=reshard_after_forward,
                 offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+                module=layer,
             )
 
         for layer_cur, layer_next in zip(
@@ -808,32 +808,31 @@ def fully_shard(
         ):
             layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
 
-        fully_shard(
-            self.embed_tokens,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.embed_tokens,
         )
 
-        fully_shard(
-            self.norm,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.norm,
         )
 
-        fully_shard(
-            self.lm_head,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+            module=self.lm_head,
         )
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
@@ -873,14 +872,25 @@ def scale_and_reduce_grad(self):
                 param.grad.div_(self.ep_mesh.size())  # type: ignore
                 continue
 
-            # Reduce gradients for other parameters
-            if ep_enabled:
-                grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
-                dist.all_reduce(
-                    grad.div_(self.ep_mesh.size()),  # type: ignore
-                    ReduceOp.SUM,
-                    group=self.ep_mesh.get_group(mesh_dim=0),  # type: ignore
+            if isinstance(param, DTensor):
+                replicate_dim_names = tuple(
+                    param.device_mesh.mesh_dim_names[i]
+                    for i, p in enumerate(param.placements)
+                    if isinstance(p, Replicate)
                 )
+                if replicate_dim_names:
+                    # `DeviceMesh.get_group()` only supports a single mesh dimension,
+                    # so calling it directly on a multi-dim sub-mesh raises RuntimeError.
+                    # `_flatten()` collapses all Replicate dims into a 1D mesh whose
+                    # process group covers every rank across those dimensions, allowing
+                    # a single all_reduce regardless of how many Replicate dims exist.
+                    flat_mesh = param.device_mesh[replicate_dim_names]._flatten()
+                    grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
+                    dist.all_reduce(
+                        grad.div_(flat_mesh.size()),  # type: ignore
+                        ReduceOp.SUM,
+                        group=flat_mesh.get_group(),  # type: ignore
+                    )
 
     def _init_device_mesh(self, fsdp_config: FSDPConfig):
         self.fsdp_config = fsdp_config
@@ -895,6 +905,7 @@ def _init_device_mesh(self, fsdp_config: FSDPConfig):
             (experts_fsdp_size, self.fsdp_config.ep_size),
             mesh_dim_names=(f"{self.config.mesh_prefix}.fsdp", f"{self.config.mesh_prefix}.ep"),
         )
+        self._world_mesh = model_mesh
         if self.ep_mesh is not None:
             # WARN: This assertion is **VERY** important.
             # FSDP requires that `device_mesh` shares the same root mesh across all mesh dimensions.
@@ -959,7 +970,9 @@ def traverse(module):
             if isinstance(module, MoEBlock):
                 return
             for name, param in module.named_parameters(recurse=False):
-                dist_param = nn.Parameter(distribute_tensor(param, self.ep_mesh, [Replicate()]))
+                dist_param = nn.Parameter(
+                    distribute_tensor(param, self.ep_mesh, [Replicate()]), requires_grad=param.requires_grad
+                )
                 module.register_parameter(name, dist_param)
             for child in module.children():
                 traverse(child)
diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py
index 3afc4e5dd..6ff218412 100644
--- a/xtuner/v1/model/moe/qwen3_5_text.py
+++ b/xtuner/v1/model/moe/qwen3_5_text.py
@@ -7,6 +7,7 @@
 from xtuner.v1.model.base import (
     DEFAULT_FLOAT8_CFG,
+    HFSaveCfg,
     TorchCompileOption,
 )
 from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig
@@ -145,6 +146,12 @@ def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
 class Qwen3_5_VLTextMoEConfig(MoEConfig):
     with_shared_expert_gate: bool = True
     rms_norm_type: Literal["default", "zero_centered"] = "zero_centered"
+    hf_save_cfg: HFSaveCfg = HFSaveCfg(
+        fp32_keys_pattern=[
+            r"model\.language_model\.layers\.\d+\.linear_attn\.norm\.weight",
+            r"model\.language_model\.layers\.\d+\.linear_attn\.A_log",
+        ],
+    )
 
     @computed_field
     def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]:
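
Note (illustration, not part of the patch): the `fp32_keys_pattern` entries configured above are plain regexes applied with `re.search` to each HF parameter name, which is how `_get_save_dtype` and `_is_ignored_params` in base.py decide whether a tensor is kept in fp32 when saving. A minimal, self-contained sketch of that matching logic, using hypothetical parameter names:

    import re
    import torch

    # Patterns copied from Qwen3_5_VLTextMoEConfig.hf_save_cfg in the diff above.
    patterns = [
        r"model\.language_model\.layers\.\d+\.linear_attn\.norm\.weight",
        r"model\.language_model\.layers\.\d+\.linear_attn\.A_log",
    ]

    def save_dtype(hf_key: str, default: torch.dtype = torch.bfloat16) -> torch.dtype:
        # Mirrors XTunerBaseModel._get_save_dtype: any re.search hit keeps the tensor in fp32.
        if any(re.search(p, hf_key) for p in patterns):
            return torch.float32
        return default

    # Hypothetical HF keys, for illustration only.
    assert save_dtype("model.language_model.layers.3.linear_attn.A_log") == torch.float32
    assert save_dtype("model.language_model.layers.3.self_attn.q_proj.weight") == torch.bfloat16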