diff --git a/xtuner/v1/config/optim.py b/xtuner/v1/config/optim.py
index 5827edd8c..dd913d235 100644
--- a/xtuner/v1/config/optim.py
+++ b/xtuner/v1/config/optim.py
@@ -7,7 +7,7 @@
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Annotated
 
-from xtuner.v1.optim import Muon
+from xtuner.v1.optim import Muon, SwapAdamW
 from xtuner.v1.utils import get_logger
 
 
@@ -32,6 +32,7 @@ class AdamWConfig(OptimConfig):
     betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for Adam optimizer")] = (0.9, 0.95)
     eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
     foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
+    swap_optimizer: Annotated[Optional[bool], Parameter(help="Keep optimizer states in pinned host memory and swap them to device only for the optimizer step.")] = False
 
     def build(self, model):
         params = [p for p in model.parameters() if p.requires_grad]
@@ -52,6 +53,15 @@ def build(self, model):
             f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
         )
         logger.info(f"Untrainable parameters names: {untrainable_names}")
+        if self.swap_optimizer:
+            return SwapAdamW(
+                params,
+                lr=self.lr,
+                betas=self.betas,
+                eps=self.eps,
+                weight_decay=self.weight_decay,
+                foreach=self.foreach,
+            )
         return torch.optim.AdamW(
             params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
         )
diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 2a576c59b..9637220f4 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -317,11 +317,19 @@ def save_dcp(
 
         with profile_time_and_memory(f"[DCP Checkpoint to {optimizer_dir}]"):
             if optimizer_dir is not None:
-                shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
-                dcp.save(
-                    shard_optimizer_state_dict,
-                    checkpoint_id=optimizer_dir,
-                )
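+                # Swap-capable optimizers (e.g. SwapAdamW) keep their states in host
+                # memory; these optional hooks move the states on-device before DCP
+                # reads them and release them back to host afterwards.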
+                prepare_for_checkpoint_save = getattr(self.optimizer, "prepare_for_checkpoint_save", None)
+                finalize_after_checkpoint_save = getattr(self.optimizer, "finalize_after_checkpoint_save", None)
+                if callable(prepare_for_checkpoint_save):
+                    prepare_for_checkpoint_save()
+                try:
+                    shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
+                    dcp.save(
+                        shard_optimizer_state_dict,
+                        checkpoint_id=optimizer_dir,
+                    )
+                finally:
+                    if callable(finalize_after_checkpoint_save):
+                        finalize_after_checkpoint_save()
 
     def load_dcp(
         self,
@@ -353,39 +361,47 @@ def load_dcp(
 
         if optimizer_dir is not None:
             with profile_time_and_memory(f"[Load DCP Optimizer] from {optimizer_dir}"):
-                shard_optimizer_state_dict = get_optimizer_state_dict(
-                    self.model, self.optimizer, options=_load_options
-                )
-                dcp.load(
-                    state_dict=shard_optimizer_state_dict,
-                    checkpoint_id=optimizer_dir,
-                )
-                if not load_states:
-                    logger.info("Not loading optimizer states")
-                    shard_optimizer_state_dict["state"] = {}
-                if not load_args:
-                    logger.info("Not loading arg defaults")
-                    param_groups = self.optimizer.state_dict()["param_groups"]
-                    # Now we only support one param_group. If we want to support different lr for different parameters,
-                    # we may use multiple param_groups like:
-                    # [{'params': ['net1.weight', 'net2.weight'], 'lr': 0.001}, {'params': ['net3.weight'], 'lr': 0.002}]
-                    # Then we need change the code here
-                    assert len(param_groups) == 1, "Only one param_group is supported now"
-                    init_defaults = param_groups[0]
-                    init_defaults.pop("params")
-                    for param_group in cast(List[Dict[str, Any]], shard_optimizer_state_dict["param_groups"]):
-                        # param_group is like: {'params': ['net1.weight', 'net2.weight'], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01}
-                        default_keys = list(filter(lambda x: x != "params", param_group.keys()))
-                        for key in default_keys:
-                            param_group.pop(key)
-                        param_group.update(init_defaults)  # lr, betas, eps, etc.
-
-                set_optimizer_state_dict(
-                    self.model,
-                    self.optimizer,
-                    optim_state_dict=shard_optimizer_state_dict,
-                    options=_set_options,
-                )
+                prepare_for_checkpoint_load = getattr(self.optimizer, "prepare_for_checkpoint_load", None)
+                finalize_after_checkpoint_load = getattr(self.optimizer, "finalize_after_checkpoint_load", None)
+                if callable(prepare_for_checkpoint_load):
+                    prepare_for_checkpoint_load()
+                try:
+                    shard_optimizer_state_dict = get_optimizer_state_dict(
+                        self.model, self.optimizer, options=_load_options
+                    )
+                    dcp.load(
+                        state_dict=shard_optimizer_state_dict,
+                        checkpoint_id=optimizer_dir,
+                    )
+                    if not load_states:
+                        logger.info("Not loading optimizer states")
+                        shard_optimizer_state_dict["state"] = {}
+                    if not load_args:
+                        logger.info("Not loading arg defaults")
+                        param_groups = self.optimizer.state_dict()["param_groups"]
+                        # Now we only support one param_group. If we want to support different lr for different parameters,
+                        # we may use multiple param_groups like:
+                        # [{'params': ['net1.weight', 'net2.weight'], 'lr': 0.001}, {'params': ['net3.weight'], 'lr': 0.002}]
+                        # Then we need to change the code here.
+                        assert len(param_groups) == 1, "Only one param_group is supported now"
+                        init_defaults = param_groups[0]
+                        init_defaults.pop("params")
+                        for param_group in cast(List[Dict[str, Any]], shard_optimizer_state_dict["param_groups"]):
+                            # param_group is like: {'params': ['net1.weight', 'net2.weight'], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01}
+                            default_keys = list(filter(lambda x: x != "params", param_group.keys()))
+                            for key in default_keys:
+                                param_group.pop(key)
+                            param_group.update(init_defaults)  # lr, betas, eps, etc.
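+
+                    # Write the merged state dict back into the optimizer. With
+                    # SwapAdamW the states are on device at this point (moved by
+                    # prepare_for_checkpoint_load), so the DCP-loaded shards land
+                    # in real device storages before being swapped back to host.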
+                    set_optimizer_state_dict(
+                        self.model,
+                        self.optimizer,
+                        optim_state_dict=shard_optimizer_state_dict,
+                        options=_set_options,
+                    )
+                finally:
+                    if callable(finalize_after_checkpoint_load):
+                        finalize_after_checkpoint_load()
 
     def put_model_to_device(self, device: torch.device | str):
         """Put the model to the given device."""
@@ -394,7 +410,7 @@ def put_model_to_device(self, device: torch.device | str):
 
     def put_optimizer_to_device(self, device: torch.device | str):
         """Put the optimizer to the given device."""
-        if self.fsdp_cfg.cpu_offload:
+        if self.fsdp_cfg.cpu_offload or getattr(self.optim_cfg, "swap_optimizer", False):
             return
         if not self.optimizer.state:
             return
diff --git a/xtuner/v1/optim/__init__.py b/xtuner/v1/optim/__init__.py
index 0ea5fdeb0..c220f2996 100644
--- a/xtuner/v1/optim/__init__.py
+++ b/xtuner/v1/optim/__init__.py
@@ -1,4 +1,5 @@
 from .muon import Muon  # type: ignore
+from .swap_adamw import SwapAdamW
 
-__all__ = ["Muon"]
+__all__ = ["Muon", "SwapAdamW"]
diff --git a/xtuner/v1/optim/swap_adamw.py b/xtuner/v1/optim/swap_adamw.py
new file mode 100644
index 000000000..b2d444a5d
--- /dev/null
+++ b/xtuner/v1/optim/swap_adamw.py
@@ -0,0 +1,237 @@
+from collections.abc import Iterable
+
+import torch
+import torch.distributed as dist
+
+from xtuner.v1.utils import get_device, get_logger, get_torch_device_module
+
+
+DEVICE = get_device()
+DEVICE_MODULE = get_torch_device_module()
+logger = get_logger()
+
+
+class SwapAdamW(torch.optim.AdamW):
+    """AdamW optimizer that swaps its states between device and host.
+
+    Optimizer states are initialized once, mirrored into pinned host tensors, and
+    swapped onto the device only while they are needed (e.g. during the optimizer
+    step), freeing device memory the rest of the time.
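+
+    A minimal usage sketch (assumes a CUDA-like device and an already-built
+    model; the training-loop wiring here is illustrative, not part of xtuner)::
+
+        optim = SwapAdamW(model.parameters(), lr=1e-5, swap_optimizer_times=16)
+        loss = model(**batch).loss
+        loss.backward()
+        optim.step()       # swap states in, run fused AdamW, swap states back out
+        optim.zero_grad()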
+    """
+
+    _state_keys = ("exp_avg", "exp_avg_sq", "max_exp_avg_sq")
+
+    def __init__(
+        self,
+        params: Iterable[torch.Tensor],
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        amsgrad: bool = False,
+        *,
+        maximize: bool = False,
+        foreach: bool | None = None,
+        capturable: bool = False,
+        differentiable: bool = False,
+        fused: bool | None = None,
+        swap_optimizer_times: int = 16,
+    ):
+        super().__init__(
+            params,
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            maximize=maximize,
+            foreach=foreach,
+            capturable=capturable,
+            differentiable=differentiable,
+            fused=fused,
+        )
+        self._swap_optimizer_times = swap_optimizer_times
+        self._param_to_group_map: dict[torch.Tensor, dict] = {}
+        self._swap_to_device_events_map: dict[torch.Tensor, torch.Event | None] = {}
+        self._swap_to_host_events_map: dict[torch.Tensor, torch.Event | None] = {}
+        self._param_to_cpu_states_map: dict[torch.Tensor, dict[str, torch.Tensor | None]] = {}
+        self._param_to_device_states_map: dict[torch.Tensor, dict[str, torch.Tensor | None]] = {}
+        self._states_on_device = False
+        self._init_swap_states()
+
+    @staticmethod
+    def _to_local_tensor(tensor: torch.Tensor) -> torch.Tensor:
+        # DTensor parameters expose their local shard via to_local(); plain tensors pass through.
+        if hasattr(tensor, "to_local"):
+            return tensor.to_local()  # type: ignore[no-any-return]
+        return tensor
+
+    def _init_swap_states(self) -> None:
+        for group in self.param_groups:
+            for param in group["params"]:
+                self._param_to_group_map[param] = group
+
+        swap_num = sum(self._to_local_tensor(main_param).numel() for main_param in self._param_to_group_map)
+        self.swap_numel = swap_num // self._swap_optimizer_times
+        # exp_avg + exp_avg_sq in fp32: 8 bytes per parameter element.
+        swap_memory = swap_num * 8 / 1024 / 1024
+        logger.info(
+            f"[Device {DEVICE_MODULE.current_device()}] swap optimizer param num: {swap_num}, "
+            f"state size: {swap_memory}MB"
+        )
+
+        for group in self.param_groups:
+            for param in group["params"]:
+                device_state_dtensor = self.state[param]
+                device_state_tensor: dict[str, torch.Tensor | None] = {}
+                cpu_state: dict[str, torch.Tensor | None] = {}
+                amsgrad = bool(self._param_to_group_map[param]["amsgrad"])
+
+                for key in self._state_keys:
+                    if key == "max_exp_avg_sq" and not amsgrad:
+                        device_state_dtensor[key] = None
+                        device_state_tensor[key] = None
+                        cpu_state[key] = None
+                    else:
+                        device_state_dtensor[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                        device_tensor = self._to_local_tensor(device_state_dtensor[key])
+                        cpu_tensor = torch.empty_like(device_tensor, pin_memory=True, device="cpu")
+                        cpu_tensor.copy_(device_tensor, non_blocking=True)
+                        # Releasing the device storage is stream-ordered through the
+                        # caching allocator, so it is safe after the enqueued D2H copy.
+                        device_tensor.storage().resize_(0)
+                        device_state_tensor[key] = device_tensor
+                        cpu_state[key] = cpu_tensor
+
+                self._param_to_device_states_map[param] = device_state_tensor
+                self._param_to_cpu_states_map[param] = cpu_state
+
+        DEVICE_MODULE.synchronize()
+
+    def swap_all_to_host(self) -> None:
+        for param in self._param_to_cpu_states_map:
+            self._swap_tensors_to_host(param)
+        for param in self._param_to_cpu_states_map:
+            event = self._swap_to_host_events_map.get(param, None)
+            if event is not None:
+                DEVICE_MODULE.current_stream().wait_event(event)
+            self._swap_to_host_events_map[param] = None
+        self._states_on_device = False
+
+    def swap_all_to_device(self) -> None:
+        dist.barrier(dist.group.WORLD)
+        for param in self._param_to_cpu_states_map:
+            self._swap_tensors_to_device(param)
+        for param in self._param_to_cpu_states_map:
+            event = self._swap_to_device_events_map.get(param, None)
+            if event is not None:
+                DEVICE_MODULE.current_stream().wait_event(event)
+            self._swap_to_device_events_map[param] = None
+        self._states_on_device = True
+
+    def _ensure_states_on_device(self) -> None:
+        if self._states_on_device:
+            return
+        self.swap_all_to_device()
+
+    def _ensure_states_on_host(self) -> None:
+        if not self._states_on_device:
+            return
+        self.swap_all_to_host()
+
+    def prepare_for_checkpoint_save(self) -> None:
+        self._ensure_states_on_device()
+
+    def finalize_after_checkpoint_save(self) -> None:
+        self._ensure_states_on_host()
+
+    def prepare_for_checkpoint_load(self) -> None:
+        self._ensure_states_on_device()
+
+    def finalize_after_checkpoint_load(self) -> None:
+        self._ensure_states_on_host()
+
+    def _swap_tensors_to_device(self, param: torch.Tensor) -> None:
+        cpu_state = self._param_to_cpu_states_map[param]
+        device_state = self._param_to_device_states_map.get(param, None)
+        if device_state is None:
+            return
+        for key in self._state_keys:
+            device_tensor = device_state.get(key, None)
+            cpu_tensor = cpu_state.get(key, None)
+            if device_tensor is None or cpu_tensor is None:
+                continue
+            if device_tensor.storage().size() == 0:
+                device_tensor.storage().resize_(cpu_tensor.storage().size())
+            device_tensor.copy_(cpu_tensor, non_blocking=True)
+
+        self._swap_to_device_events_map[param] = DEVICE_MODULE.current_stream().record_event()
+
+    def _swap_tensors_to_host(self, param: torch.Tensor) -> None:
+        cpu_state = self._param_to_cpu_states_map[param]
+        device_state = self._param_to_device_states_map.get(param, None)
+        if device_state is None:
+            return
+        for key in self._state_keys:
+            device_tensor = device_state.get(key, None)
+            cpu_tensor = cpu_state.get(key, None)
+            if device_tensor is None or cpu_tensor is None:
+                continue
+            if device_tensor.storage().size() != 0:
+                cpu_tensor.copy_(device_tensor, non_blocking=True)
+                device_tensor.storage().resize_(0)
+
+        self._swap_to_host_events_map[param] = DEVICE_MODULE.current_stream().record_event()
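+
+    # step() mirrors torch.optim.AdamW's fused path, but brackets the fused kernel
+    # with a swap-in before the update and a swap-out after it, so the full states
+    # occupy device memory only for the duration of the step.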
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if "step" in group:
+                group["step"] += 1
+                if group["step"].is_cpu:
+                    group["step"] = group["step"].to(DEVICE)
+            else:
+                group["step"] = torch.tensor(1, dtype=torch.int64, device=DEVICE_MODULE.current_device())
+
+        params_list = list(self._param_to_group_map.keys())
+        self._ensure_states_on_device()
+
+        for param in params_list:
+            if param.grad is None:
+                continue
+            if param.grad.is_sparse:
+                raise RuntimeError("SwapAdamW does not support sparse gradients")
+
+            group = self._param_to_group_map[param]
+            amsgrad = bool(group["amsgrad"])
+            beta1, beta2 = group["betas"]
+            state = self.state[param]
+
+            exp_avg = state["exp_avg"]
+            exp_avg_sq = state["exp_avg_sq"]
+            max_exp_avg_sq = state.get("max_exp_avg_sq", None)
+            assert isinstance(exp_avg, torch.Tensor)
+            assert isinstance(exp_avg_sq, torch.Tensor)
+            if amsgrad:
+                assert isinstance(max_exp_avg_sq, torch.Tensor)
+
+            torch._fused_adamw_(
+                [self._to_local_tensor(param)],
+                [self._to_local_tensor(param.grad)],
+                [self._to_local_tensor(exp_avg)],
+                [self._to_local_tensor(exp_avg_sq)],
+                [self._to_local_tensor(max_exp_avg_sq)] if amsgrad else [],
+                [group["step"]],
+                amsgrad=amsgrad,
+                lr=group["lr"],
+                beta1=beta1,
+                beta2=beta2,
+                weight_decay=group["weight_decay"],
+                eps=group["eps"],
+                maximize=group["maximize"],
+            )
+
+        self._ensure_states_on_host()
+        DEVICE_MODULE.synchronize()
+        return loss
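
A minimal sketch of how the new flag is enabled end to end. The config and
optimizer names come from this diff; `model` and the surrounding trainer
wiring are assumed, not shown here:

```python
from xtuner.v1.config.optim import AdamWConfig

# With swap_optimizer=True, build() returns SwapAdamW instead of torch.optim.AdamW,
# so exp_avg / exp_avg_sq live in pinned host memory between optimizer steps.
optim_cfg = AdamWConfig(lr=1e-5, swap_optimizer=True)
optimizer = optim_cfg.build(model)  # `model` is the FSDP-wrapped module built by the engine
```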