-
Notifications
You must be signed in to change notification settings - Fork 413
[Fix] Support fp32 param preservation during FSDP save and load #1607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,11 @@ | |
| from xtuner.v1.config import FSDPConfig | ||
| from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh | ||
|
|
||
| import tempfile | ||
| from pathlib import Path | ||
| import json | ||
| from safetensors import safe_open | ||
|
Comment on lines
+18
to
+21
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Nit — Import placement. Per convention in this file, stdlib and third-party imports should be grouped with the existing imports at the top (lines 1–15), not appended after. |
||
|
|
||
|
|
||
| VIDEO_ROOT = os.environ["VIDEO_ROOT"] | ||
|
|
||
|
|
@@ -216,6 +221,82 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol): | |
| self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol)) | ||
| self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol)) | ||
|
|
||
| @parametrize.parametrize( | ||
| "device,sp_size", | ||
| [ | ||
| ("cuda", 1), | ||
| ], | ||
| ) | ||
| def test_save_hf_with_mtp(self, device, sp_size): | ||
| self.create_pg(device) | ||
| QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"] | ||
|
|
||
| with torch.device("meta"): | ||
| model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False) | ||
| qwen3vl_model = model_cfg.build().to(torch.bfloat16) | ||
|
|
||
| fsdp_config = FSDPConfig(cpu_offload=False) | ||
| fsdp_mesh = init_world_mesh() | ||
| qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh | ||
| qwen3vl_model.vision_tower.fsdp_config = fsdp_config | ||
| qwen3vl_model.fully_shard(fsdp_config=fsdp_config) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| syncdir = [tmpdir] | ||
| dist.broadcast_object_list(syncdir, src=0) | ||
| tmpdir = Path(syncdir[0]) | ||
| qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH) | ||
| qwen3vl_model.save_hf(tmpdir) | ||
|
|
||
| origin_hf_path = Path(QWEN3_VL_MOE_PATH) | ||
| origin_index_path = origin_hf_path / "model.safetensors.index.json" | ||
| saved_index_path = tmpdir / "model.safetensors.index.json" | ||
|
|
||
| if dist.get_rank() == 0: | ||
| with open(origin_index_path, "r") as f: | ||
| origin_index = json.load(f) | ||
| with open(saved_index_path, "r") as f: | ||
| saved_index = json.load(f) | ||
|
|
||
| cache_save_fh: dict = {} | ||
|
|
||
| # Verify all original HF weights are preserved correctly | ||
| for key in origin_index["weight_map"].keys(): | ||
| if "mtp" in key: | ||
| continue # TODO: remove this after MTP is implemented | ||
| origin_safetensor_name = origin_index["weight_map"][key] | ||
| saved_safetensor_name = saved_index["weight_map"][key] | ||
|
|
||
| origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name) | ||
| saved_sf_fh_name = str(tmpdir / saved_safetensor_name) | ||
|
|
||
| if origin_sf_fh_name not in cache_save_fh: | ||
| cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt") | ||
| if saved_sf_fh_name not in cache_save_fh: | ||
| cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt") | ||
|
|
||
| origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key) | ||
| saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key) | ||
|
|
||
| self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}") | ||
|
|
||
| # Verify MTP weights are present in the saved output | ||
| mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")] | ||
| # TODO: remove skip after MTP is implemented | ||
| _ = mtp_keys | ||
|
|
||
| # Verify the tensor count in safetensors matches the saved index | ||
| safetensor_keys: list[str] = [] | ||
| for safetensor_path in tmpdir.glob("*.safetensors"): | ||
| fh = safe_open(str(safetensor_path), framework="pt") | ||
| safetensor_keys.extend(fh.keys()) | ||
| safetensor_keys.sort() | ||
| model_index_keys = list(saved_index["weight_map"].keys()) | ||
| model_index_keys.sort() | ||
| self.assertListEqual(safetensor_keys, model_index_keys) | ||
|
|
||
| dist.barrier() | ||
|
|
||
| @property | ||
| def world_size(self) -> int: | ||
| return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "4")) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,7 +26,7 @@ | |||||||||||||||||||||||||||
| MixedPrecisionPolicy, | ||||||||||||||||||||||||||||
| fully_shard, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| from torch.distributed.tensor import DTensor, Placement, Shard | ||||||||||||||||||||||||||||
| from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor | ||||||||||||||||||||||||||||
| from torch.distributed.tensor._utils import compute_local_shape_and_global_offset | ||||||||||||||||||||||||||||
| from typing_extensions import NotRequired, Self, TypedDict, overload | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel): | |||||||||||||||||||||||||||
| worker_per_rank: Annotated[int, Parameter(group="model")] = 16 | ||||||||||||||||||||||||||||
| max_save_rank: Annotated[int, Parameter(group="model")] = 16 | ||||||||||||||||||||||||||||
| bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4 | ||||||||||||||||||||||||||||
| # TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32, | ||||||||||||||||||||||||||||
| # currently it could only be specified in HFSaveCfg | ||||||||||||||||||||||||||||
| # Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name. | ||||||||||||||||||||||||||||
| # Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of | ||||||||||||||||||||||||||||
| # r"model.layers.\d+.weight" to avoid unintended wildcard matches. | ||||||||||||||||||||||||||||
| fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None | ||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Nit — Same regex dot-escaping concern applies here. Patterns from |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class XTunerBaseModelConfig(PydanticBaseModel): | ||||||||||||||||||||||||||||
|
|
@@ -313,6 +319,7 @@ def fully_shard( | |||||||||||||||||||||||||||
| """Fully shard the model parameters.""" | ||||||||||||||||||||||||||||
| self.fsdp_config = fsdp_config | ||||||||||||||||||||||||||||
| self.fsdp_mesh = self._init_world_mesh() | ||||||||||||||||||||||||||||
| self._world_mesh = self.fsdp_mesh | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if self.fsdp_config.requires_grad: | ||||||||||||||||||||||||||||
| for name, module in self.named_modules(): | ||||||||||||||||||||||||||||
|
|
@@ -337,15 +344,79 @@ def fully_shard( | |||||||||||||||||||||||||||
| mp_policy = MixedPrecisionPolicy( | ||||||||||||||||||||||||||||
| param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| fully_shard( | ||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||
| self._fully_shard( | ||||||||||||||||||||||||||||
| mesh=self.fsdp_mesh, | ||||||||||||||||||||||||||||
| mp_policy=mp_policy, | ||||||||||||||||||||||||||||
| reshard_after_forward=fsdp_config.reshard_after_forward, | ||||||||||||||||||||||||||||
| offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _fully_shard( | ||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||
| mesh: DeviceMesh, | ||||||||||||||||||||||||||||
| mp_policy: MixedPrecisionPolicy, | ||||||||||||||||||||||||||||
| reshard_after_forward: bool, | ||||||||||||||||||||||||||||
| offload_policy: CPUOffloadPolicy | None, | ||||||||||||||||||||||||||||
| module: nn.Module | None = None, | ||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||
| def traverse(module): | ||||||||||||||||||||||||||||
| for name, param in module.named_parameters(recurse=False): | ||||||||||||||||||||||||||||
| full_name = full_param_name_mapping[id(param)] | ||||||||||||||||||||||||||||
| full_name = self._clean_param_name(full_name) | ||||||||||||||||||||||||||||
| hf_name_list = self.to_hf_key_list(full_name) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for hf_name in hf_name_list: | ||||||||||||||||||||||||||||
| if any(re.search(p, hf_name) for p in patterns): # type: ignore | ||||||||||||||||||||||||||||
| if not isinstance(param, DTensor): | ||||||||||||||||||||||||||||
|
Comment on lines
+364
to
+371
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Warning — In Downstream in Consider guarding the yield:
Suggested change
The same pattern should be applied in |
||||||||||||||||||||||||||||
| dist_param = nn.Parameter( | ||||||||||||||||||||||||||||
| distribute_tensor( | ||||||||||||||||||||||||||||
| param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)] | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| requires_grad=param.requires_grad, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| module.register_parameter(name, dist_param) | ||||||||||||||||||||||||||||
| ignored_params.add(dist_param) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| # param is already a DTensor (e.g. distributed by | ||||||||||||||||||||||||||||
| # MoE._replicate_other_params on ep_mesh before _fully_shard | ||||||||||||||||||||||||||||
| # is called). We skip re-distributing on world_mesh and just | ||||||||||||||||||||||||||||
| # add it to ignored_params so FSDP leaves it alone. | ||||||||||||||||||||||||||||
| # ASSUMPTION: fp32 distribution always happens AFTER any | ||||||||||||||||||||||||||||
| # prior EP distribution, so the existing placement is correct. | ||||||||||||||||||||||||||||
| ignored_params.add(param) | ||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||
|
Comment on lines
+372
to
+388
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: The One subtle note: when a param is already a DTensor (e.g., after |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for child in module.children(): | ||||||||||||||||||||||||||||
| traverse(child) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Collect the parameters of `target` that match any fp32 pattern so they can be | ||||||||||||||||||||||||||||
| # excluded from FSDP sharding (passed as `ignored_params`). | ||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||
| # We intentionally iterate over `self.named_parameters()` rather than | ||||||||||||||||||||||||||||
| # `target.named_parameters()` so that `name` is always relative to the root model | ||||||||||||||||||||||||||||
| # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`): | ||||||||||||||||||||||||||||
| # `target.named_parameters()` would yield bare names like `"weight"`, which | ||||||||||||||||||||||||||||
| # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the | ||||||||||||||||||||||||||||
| # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters | ||||||||||||||||||||||||||||
| # using identity comparison. | ||||||||||||||||||||||||||||
| full_param_name_mapping = {id(param): name for name, param in self.named_parameters()} | ||||||||||||||||||||||||||||
| ignored_params: set[nn.Parameter] = set() | ||||||||||||||||||||||||||||
| patterns = self.config.hf_save_cfg.fp32_keys_pattern | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| target = module or self | ||||||||||||||||||||||||||||
| if patterns: | ||||||||||||||||||||||||||||
| traverse(target) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| fully_shard( | ||||||||||||||||||||||||||||
| target, | ||||||||||||||||||||||||||||
| mesh=mesh, | ||||||||||||||||||||||||||||
| mp_policy=mp_policy, | ||||||||||||||||||||||||||||
| reshard_after_forward=reshard_after_forward, | ||||||||||||||||||||||||||||
| offload_policy=offload_policy, | ||||||||||||||||||||||||||||
| ignored_params=ignored_params if ignored_params else None, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"): | ||||||||||||||||||||||||||||
| with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"): | ||||||||||||||||||||||||||||
| self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix) | ||||||||||||||||||||||||||||
|
|
@@ -396,6 +467,12 @@ def device(self) -> torch.device: | |||||||||||||||||||||||||||
| return torch.device("cpu") | ||||||||||||||||||||||||||||
| return torch.device(DEVICE) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||
| def world_mesh(self) -> DeviceMesh | None: | ||||||||||||||||||||||||||||
| if not hasattr(self, "_world_mesh"): | ||||||||||||||||||||||||||||
| self._world_mesh = self._init_world_mesh() | ||||||||||||||||||||||||||||
| return self._world_mesh | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||
| def default_compile_cfg(self) -> dict[str, TorchCompileOption]: | ||||||||||||||||||||||||||||
| return {} | ||||||||||||||||||||||||||||
|
|
@@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat | |||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
| return ret | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype: | ||||||||||||||||||||||||||||
| patterns = self.config.hf_save_cfg.fp32_keys_pattern | ||||||||||||||||||||||||||||
| if patterns and any(re.search(p, name) for p in patterns): | ||||||||||||||||||||||||||||
| return torch.float32 | ||||||||||||||||||||||||||||
| return dtype | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _get_shard_hf_param( | ||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||
| params: list[tuple[torch.Tensor, LoadSpec]], | ||||||||||||||||||||||||||||
|
|
@@ -679,6 +762,16 @@ def _get_shard_hf_param( | |||||||||||||||||||||||||||
| ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]: | ||||||||||||||||||||||||||||
| if not params: | ||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ignored_params, params = self._split_ignored_params(params) | ||||||||||||||||||||||||||||
| if ignored_params: | ||||||||||||||||||||||||||||
| name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params] | ||||||||||||||||||||||||||||
| hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params] | ||||||||||||||||||||||||||||
| yield name_list, hf_params | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if not params: | ||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if dtype != torch.bfloat16: | ||||||||||||||||||||||||||||
| raise NotImplementedError | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis | |||||||||||||||||||||||||||
| # Get unsharded params | ||||||||||||||||||||||||||||
| _unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group) | ||||||||||||||||||||||||||||
| unsharded_tensor_list = [ | ||||||||||||||||||||||||||||
| torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list | ||||||||||||||||||||||||||||
| torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list] | ||||||||||||||||||||||||||||
| unsharded_tensor_list = [ | ||||||||||||||||||||||||||||
|
|
@@ -711,11 +804,11 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| safetensor_size = 0 | ||||||||||||||||||||||||||||
| tensor_list: list[tuple[torch.Tensor, LoadSpec]] = [] | ||||||||||||||||||||||||||||
| name_list: list[str] = [] | ||||||||||||||||||||||||||||
| name_list = [] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for param, load_spec in params: | ||||||||||||||||||||||||||||
| local_tensor = param._local_tensor if isinstance(param, DTensor) else param | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.to(dtype=dtype) | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16)) | ||||||||||||||||||||||||||||
| tensor_size = self._get_tensor_size(param, dtype) | ||||||||||||||||||||||||||||
| if safetensor_size + tensor_size > bucket_size and tensor_list: | ||||||||||||||||||||||||||||
| hf_params = _get_hf_params(tensor_list) | ||||||||||||||||||||||||||||
|
|
@@ -744,6 +837,12 @@ def _get_fused_hf_param( | |||||||||||||||||||||||||||
| if not params: | ||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ignored_params, params = self._split_ignored_params(params) | ||||||||||||||||||||||||||||
| if ignored_params: | ||||||||||||||||||||||||||||
| fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params] | ||||||||||||||||||||||||||||
| fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params] | ||||||||||||||||||||||||||||
| yield fp32_name_list, fp32_params | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _get_hf_params( | ||||||||||||||||||||||||||||
| fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]], | ||||||||||||||||||||||||||||
| name_list: list[str], | ||||||||||||||||||||||||||||
|
|
@@ -867,7 +966,7 @@ def _get_hf_params( | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for param, load_spec in params: | ||||||||||||||||||||||||||||
| local_tensor = param._local_tensor if isinstance(param, DTensor) else param | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.bfloat16() | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16)) | ||||||||||||||||||||||||||||
| tensor_size = self._get_tensor_size(param, dtype) | ||||||||||||||||||||||||||||
| if safetensor_size + tensor_size > bucket_size and tensor_list: | ||||||||||||||||||||||||||||
| hf_params, name_list = _get_hf_params(tensor_list, name_list) | ||||||||||||||||||||||||||||
|
|
@@ -893,6 +992,15 @@ def _get_same_hf_param( | |||||||||||||||||||||||||||
| ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]: | ||||||||||||||||||||||||||||
| if not params: | ||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ignored_params, params = self._split_ignored_params(params) | ||||||||||||||||||||||||||||
| if ignored_params: | ||||||||||||||||||||||||||||
| fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params] | ||||||||||||||||||||||||||||
| fp32_tensor_list: list[torch.Tensor] = [ | ||||||||||||||||||||||||||||
| param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| yield fp32_name_list, fp32_tensor_list | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if bucket_size is None: | ||||||||||||||||||||||||||||
| bucket_size = self.config.hf_save_cfg.bucket_size | ||||||||||||||||||||||||||||
| safetensor_size = 0 | ||||||||||||||||||||||||||||
|
|
@@ -909,7 +1017,7 @@ def _get_same_hf_param( | |||||||||||||||||||||||||||
| buffer_name_list.append(load_spec.hf_keys[0]) | ||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||
| local_tensor = param._local_tensor if isinstance(param, DTensor) else param | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.bfloat16() | ||||||||||||||||||||||||||||
| local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16)) | ||||||||||||||||||||||||||||
| tensor_size = self._get_tensor_size(param, dtype) | ||||||||||||||||||||||||||||
| if safetensor_size + tensor_size > bucket_size and tensor_list: | ||||||||||||||||||||||||||||
| if self.fsdp_mesh is not None: | ||||||||||||||||||||||||||||
|
|
@@ -953,6 +1061,21 @@ def _get_same_hf_param( | |||||||||||||||||||||||||||
| if buffer_tensor_list: | ||||||||||||||||||||||||||||
| yield buffer_name_list, buffer_tensor_list | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _is_ignored_params(self, key: str): | ||||||||||||||||||||||||||||
| patterns = self.config.hf_save_cfg.fp32_keys_pattern | ||||||||||||||||||||||||||||
| if patterns is None: | ||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||
| return any(re.search(p, key) for p in patterns) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _split_ignored_params( | ||||||||||||||||||||||||||||
| self, params: list[tuple[torch.Tensor, LoadSpec]] | ||||||||||||||||||||||||||||
| ) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]: | ||||||||||||||||||||||||||||
| if not self.config.hf_save_cfg.fp32_keys_pattern: | ||||||||||||||||||||||||||||
| return [], params | ||||||||||||||||||||||||||||
| ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])] | ||||||||||||||||||||||||||||
| remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])] | ||||||||||||||||||||||||||||
| return ignored_params, remaining | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # TODO: Using `xtuenr.v1.utils.misc.clean_param_name` | ||||||||||||||||||||||||||||
| def _clean_param_name(self, name: str) -> str: | ||||||||||||||||||||||||||||
| if "_checkpoint_wrapped_module." in name: | ||||||||||||||||||||||||||||
|
|
@@ -1230,7 +1353,12 @@ def _load_same_hf_param( | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| loaded_tensor = loaded_tensor.to(local_tensor.device) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if self.fsdp_mesh is not None and isinstance(param, nn.Parameter): | ||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||
| self.fsdp_mesh is not None | ||||||||||||||||||||||||||||
| and isinstance(param, nn.Parameter) | ||||||||||||||||||||||||||||
| and isinstance(param, DTensor) | ||||||||||||||||||||||||||||
| and any(isinstance(p, Shard) for p in param.placements) | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| shape_before_fsdp = load_spec.shape | ||||||||||||||||||||||||||||
| _, _offset = compute_local_shape_and_global_offset( | ||||||||||||||||||||||||||||
| shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Claude: Two issues with
test_save_hf_with_mtp:Warning — Test does not verify fp32 dtype preservation. The test checks
torch.equal(origin_tensor, saved_tensor)for value equality, but never asserts that parameters matchingfp32_keys_pattern(e.g.,linear_attn.norm.weight,linear_attn.A_log) are saved intorch.float32dtype. This is the core feature of the PR — consider adding dtype assertions.Nit —
cache_save_fhdict uses safetensor filenames as keys. Both origin and saved dirs may produce the same filenames, causing key collisions. Use full file paths as dict keys instead.