12 changes: 11 additions & 1 deletion xtuner/v1/config/optim.py
@@ -7,7 +7,7 @@
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated

-from xtuner.v1.optim import Muon
+from xtuner.v1.optim import Muon, SwapAdamW
from xtuner.v1.utils import get_logger


@@ -32,6 +32,7 @@ class AdamWConfig(OptimConfig):
betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for Adam optimizer")] = (0.9, 0.95)
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
+    swap_optimizer: Annotated[Optional[bool], Parameter(help="Swap optimizer states to host memory.")] = False

def build(self, model):
params = [p for p in model.parameters() if p.requires_grad]
@@ -52,6 +53,15 @@ def build(self, model):
f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
)
logger.info(f"Untrainable parameters names: {untrainable_names}")
+        if self.swap_optimizer:
+            return SwapAdamW(
+                params,
+                lr=self.lr,
+                betas=self.betas,
+                eps=self.eps,
+                weight_decay=self.weight_decay,
+                foreach=self.foreach,
+            )
return torch.optim.AdamW(
params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
)
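
A minimal usage sketch (not part of this diff) of the new flag; `lr` and `weight_decay` are assumed to be fields of the base `OptimConfig`, since `build` references them as `self.lr` and `self.weight_decay`:

optim_cfg = AdamWConfig(lr=1e-5, swap_optimizer=True)
optimizer = optim_cfg.build(model)  # SwapAdamW when the flag is set, plain torch.optim.AdamW otherwise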
94 changes: 55 additions & 39 deletions xtuner/v1/engine/train_engine.py
@@ -317,11 +317,19 @@ def save_dcp(

with profile_time_and_memory(f"[DCP Checkpoint to {optimizer_dir}]"):
if optimizer_dir is not None:
-                shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
-                dcp.save(
-                    shard_optimizer_state_dict,
-                    checkpoint_id=optimizer_dir,
-                )
+                prepare_for_checkpoint_save = getattr(self.optimizer, "prepare_for_checkpoint_save", None)
+                finalize_after_checkpoint_save = getattr(self.optimizer, "finalize_after_checkpoint_save", None)
+                if callable(prepare_for_checkpoint_save):
+                    prepare_for_checkpoint_save()
+                try:
+                    shard_optimizer_state_dict = get_optimizer_state_dict(self.model, self.optimizer, options=_options)
+                    dcp.save(
+                        shard_optimizer_state_dict,
+                        checkpoint_id=optimizer_dir,
+                    )
+                finally:
+                    if callable(finalize_after_checkpoint_save):
+                        finalize_after_checkpoint_save()
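
The hooks are discovered by duck typing, so `save_dcp` behaves exactly as before for optimizers that lack them. A minimal sketch (hypothetical class, not part of this PR) of the contract the `getattr` calls above expect:

class HostOffloadedOptimizer(torch.optim.AdamW):
    def prepare_for_checkpoint_save(self) -> None:
        # make optimizer states visible to get_optimizer_state_dict, e.g. swap them onto the device
        ...

    def finalize_after_checkpoint_save(self) -> None:
        # undo the preparation; runs in the finally block even if dcp.save raises
        ...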

def load_dcp(
self,
@@ -353,39 +361,47 @@ def load_dcp(

if optimizer_dir is not None:
with profile_time_and_memory(f"[Load DCP Optimizer] from {optimizer_dir}"):
-                shard_optimizer_state_dict = get_optimizer_state_dict(
-                    self.model, self.optimizer, options=_load_options
-                )
-                dcp.load(
-                    state_dict=shard_optimizer_state_dict,
-                    checkpoint_id=optimizer_dir,
-                )
-                if not load_states:
-                    logger.info("Not loading optimizer states")
-                    shard_optimizer_state_dict["state"] = {}
-                if not load_args:
-                    logger.info("Not loading arg defaults")
-                    param_groups = self.optimizer.state_dict()["param_groups"]
-                    # Now we only support one param_group. If we want to support different lr for different parameters,
-                    # we may use multiple param_groups like:
-                    # [{'params': ['net1.weight', 'net2.weight'], 'lr': 0.001}, {'params': ['net3.weight'], 'lr': 0.002}]
-                    # Then we need change the code here
-                    assert len(param_groups) == 1, "Only one param_group is supported now"
-                    init_defaults = param_groups[0]
-                    init_defaults.pop("params")
-                    for param_group in cast(List[Dict[str, Any]], shard_optimizer_state_dict["param_groups"]):
-                        # param_group is like: {'params': ['net1.weight', 'net2.weight'], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01}
-                        default_keys = list(filter(lambda x: x != "params", param_group.keys()))
-                        for key in default_keys:
-                            param_group.pop(key)
-                        param_group.update(init_defaults)  # lr, betas, eps, etc.
-
-                set_optimizer_state_dict(
-                    self.model,
-                    self.optimizer,
-                    optim_state_dict=shard_optimizer_state_dict,
-                    options=_set_options,
-                )
+                prepare_for_checkpoint_load = getattr(self.optimizer, "prepare_for_checkpoint_load", None)
+                finalize_after_checkpoint_load = getattr(self.optimizer, "finalize_after_checkpoint_load", None)
+                if callable(prepare_for_checkpoint_load):
+                    prepare_for_checkpoint_load()
+                try:
+                    shard_optimizer_state_dict = get_optimizer_state_dict(
+                        self.model, self.optimizer, options=_load_options
+                    )
+                    dcp.load(
+                        state_dict=shard_optimizer_state_dict,
+                        checkpoint_id=optimizer_dir,
+                    )
+                    if not load_states:
+                        logger.info("Not loading optimizer states")
+                        shard_optimizer_state_dict["state"] = {}
+                    if not load_args:
+                        logger.info("Not loading arg defaults")
+                        param_groups = self.optimizer.state_dict()["param_groups"]
+                        # Now we only support one param_group. If we want to support different lr for different parameters,
+                        # we may use multiple param_groups like:
+                        # [{'params': ['net1.weight', 'net2.weight'], 'lr': 0.001}, {'params': ['net3.weight'], 'lr': 0.002}]
+                        # Then we need to change the code here.
+                        assert len(param_groups) == 1, "Only one param_group is supported now"
+                        init_defaults = param_groups[0]
+                        init_defaults.pop("params")
+                        for param_group in cast(List[Dict[str, Any]], shard_optimizer_state_dict["param_groups"]):
+                            # param_group is like: {'params': ['net1.weight', 'net2.weight'], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01}
+                            default_keys = list(filter(lambda x: x != "params", param_group.keys()))
+                            for key in default_keys:
+                                param_group.pop(key)
+                            param_group.update(init_defaults)  # lr, betas, eps, etc.
+
+                    set_optimizer_state_dict(
+                        self.model,
+                        self.optimizer,
+                        optim_state_dict=shard_optimizer_state_dict,
+                        options=_set_options,
+                    )
+                finally:
+                    if callable(finalize_after_checkpoint_load):
+                        finalize_after_checkpoint_load()
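
As a worked example (values illustrative), the `load_args=False` branch above rewrites each loaded param_group like this:

loaded_group = {"params": ["net1.weight"], "lr": 0.002, "eps": 1e-6}  # from the checkpoint
init_defaults = {"lr": 0.001, "betas": (0.9, 0.95), "eps": 1e-8}      # from the freshly built optimizer
# every key except "params" is popped, then the in-memory defaults are re-applied:
# result: {"params": ["net1.weight"], "lr": 0.001, "betas": (0.9, 0.95), "eps": 1e-8}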

def put_model_to_device(self, device: torch.device | str):
"""Put the model to the given device."""
@@ -394,7 +410,7 @@ def put_model_to_device(self, device: torch.device | str):

def put_optimizer_to_device(self, device: torch.device | str):
"""Put the optimizer to the given device."""
-        if self.fsdp_cfg.cpu_offload:
+        if self.fsdp_cfg.cpu_offload or self.optim_cfg.swap_optimizer:
return
if not self.optimizer.state:
return
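
With swap_optimizer enabled, SwapAdamW keeps its states in pinned host memory by design, so they are skipped here just as they are under FSDP cpu_offload.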
3 changes: 2 additions & 1 deletion xtuner/v1/optim/__init__.py
@@ -1,4 +1,5 @@
from .muon import Muon # type: ignore
+from .swap_adamw import SwapAdamW


-__all__ = ["Muon"]
+__all__ = ["Muon", "SwapAdamW"]
237 changes: 237 additions & 0 deletions xtuner/v1/optim/swap_adamw.py
@@ -0,0 +1,237 @@
from collections.abc import Iterable

import torch
import torch.distributed as dist

from xtuner.v1.utils import get_device, get_logger, get_torch_device_module


DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()
logger = get_logger()


class SwapAdamW(torch.optim.AdamW):
"""AdamW optimizer with optimizer-state swap between device and host.

Optimizer states are initialized once, mirrored to pinned host tensors, and can be swapped to device only when
needed (e.g. during optimizer step).
"""

_state_keys = ("exp_avg", "exp_avg_sq", "max_exp_avg_sq")

def __init__(
self,
params: Iterable[torch.Tensor],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
*,
maximize: bool = False,
foreach: bool | None = None,
capturable: bool = False,
differentiable: bool = False,
fused: bool | None = None,
swap_optimizer_times: int = 16,
):
super().__init__(
params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
self._swap_optimizer_times = swap_optimizer_times
self._param_to_group_map: dict[torch.Tensor, dict] = {}
self._swap_to_device_events_map: dict[torch.Tensor, torch.Event | None] = {}
self._swap_to_host_events_map: dict[torch.Tensor, torch.Event | None] = {}
self._param_to_cpu_states_map: dict[torch.Tensor, dict[str, torch.Tensor | None]] = {}
self._param_to_device_states_map: dict[torch.Tensor, dict[str, torch.Tensor | None]] = {}
self._states_on_device = False
self._init_swap_states()

@staticmethod
def _to_local_tensor(tensor: torch.Tensor) -> torch.Tensor:
if hasattr(tensor, "to_local"):
return tensor.to_local() # type: ignore[no-any-return]
return tensor

def _init_swap_states(self) -> None:
for group in self.param_groups:
for param in group["params"]:
self._param_to_group_map[param] = group

swap_num = sum(self._to_local_tensor(main_param).numel() for main_param in self._param_to_group_map)
self.swap_numel = swap_num // self._swap_optimizer_times
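        # Intended per-slice element budget when states are swapped in swap_optimizer_times chunks.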
swap_memory = swap_num * 8 / 1024 / 1024
        logger.info(
            f"[Rank {DEVICE_MODULE.current_device()}] swap optimizer param num: {swap_num}, "
            f"param size: {swap_memory:.1f}MB"
        )

for group in self.param_groups:
for param in group["params"]:
device_state_dtensor = self.state[param]
device_state_tensor: dict[str, torch.Tensor | None] = {}
cpu_state: dict[str, torch.Tensor | None] = {}
amsgrad = bool(self._param_to_group_map[param]["amsgrad"])

for key in self._state_keys:
if key == "max_exp_avg_sq" and not amsgrad:
device_state_dtensor[key] = None
device_state_tensor[key] = None
cpu_state[key] = None
else:
device_state_dtensor[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
device_tensor = self._to_local_tensor(device_state_dtensor[key])
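                        # Mirror the fresh device state into a pinned host buffer, then free the
                        # device copy by shrinking its storage to zero bytes; the tensor object
                        # survives as a handle that _swap_tensors_to_device can re-materialize.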
cpu_tensor = torch.empty_like(device_tensor, pin_memory=True, device="cpu")
cpu_tensor.copy_(device_tensor, non_blocking=True)
device_tensor.storage().resize_(0)
device_state_tensor[key] = device_tensor
cpu_state[key] = cpu_tensor

self._param_to_device_states_map[param] = device_state_tensor
self._param_to_cpu_states_map[param] = cpu_state

DEVICE_MODULE.synchronize()

def swap_all_to_host(self) -> None:
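        # Issue every async device-to-host copy first, then wait on the recorded
        # events, so the transfers overlap instead of serializing per parameter.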
for param in self._param_to_cpu_states_map:
self._swap_tensors_to_host(param)
for param in self._param_to_cpu_states_map:
event = self._swap_to_host_events_map.get(param, None)
if event is not None:
DEVICE_MODULE.current_stream().wait_event(event)
self._swap_to_host_events_map[param] = None
self._states_on_device = False

def swap_all_to_device(self) -> None:
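        # Align all ranks before the host-to-device burst, then reuse the
        # issue-all-copies-first, wait-afterwards pattern of swap_all_to_host.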
dist.barrier(dist.group.WORLD)
for param in self._param_to_cpu_states_map:
self._swap_tensors_to_device(param)
for param in self._param_to_cpu_states_map:
event = self._swap_to_device_events_map.get(param, None)
if event is not None:
DEVICE_MODULE.current_stream().wait_event(event)
self._swap_to_device_events_map[param] = None
self._states_on_device = True

def _ensure_states_on_device(self) -> None:
if self._states_on_device:
return
self.swap_all_to_device()

def _ensure_states_on_host(self) -> None:
if not self._states_on_device:
return
self.swap_all_to_host()

def prepare_for_checkpoint_save(self) -> None:
self._ensure_states_on_device()

def finalize_after_checkpoint_save(self) -> None:
self._ensure_states_on_host()

def prepare_for_checkpoint_load(self) -> None:
self._ensure_states_on_device()

def finalize_after_checkpoint_load(self) -> None:
self._ensure_states_on_host()

def _swap_tensors_to_device(self, param: torch.Tensor) -> None:
cpu_state = self._param_to_cpu_states_map[param]
device_state = self._param_to_device_states_map.get(param, None)
if device_state is None:
return
for key in self._state_keys:
device_tensor = device_state.get(key, None)
cpu_tensor = cpu_state.get(key, None)
if device_tensor is None or cpu_tensor is None:
continue
if device_tensor.storage().size() == 0:
device_tensor.storage().resize_(cpu_tensor.storage().size())
device_tensor.copy_(cpu_tensor, non_blocking=True)

self._swap_to_device_events_map[param] = DEVICE_MODULE.current_stream().record_event()

def _swap_tensors_to_host(self, param: torch.Tensor) -> None:
cpu_state = self._param_to_cpu_states_map[param]
device_state = self._param_to_device_states_map.get(param, None)
if device_state is None:
return
for key in self._state_keys:
device_tensor = device_state.get(key, None)
cpu_tensor = cpu_state.get(key, None)
if device_tensor is None or cpu_tensor is None:
continue
if device_tensor.storage().size() != 0:
cpu_tensor.copy_(device_tensor, non_blocking=True)
device_tensor.storage().resize_(0)

self._swap_to_host_events_map[param] = DEVICE_MODULE.current_stream().record_event()

@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if "step" in group:
group["step"] += 1
if group["step"].is_cpu:
group["step"] = group["step"].to(DEVICE)
else:
group["step"] = torch.tensor(1, dtype=torch.int64, device=DEVICE_MODULE.current_device())

params_list = list(self._param_to_group_map.keys())
self._ensure_states_on_device()

for param in params_list:
if param.grad is None:
continue
if param.grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")

group = self._param_to_group_map[param]
amsgrad = bool(group["amsgrad"])
beta1, beta2 = group["betas"]
state = self.state[param]

exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
max_exp_avg_sq = state.get("max_exp_avg_sq", None)
assert isinstance(exp_avg, torch.Tensor)
assert isinstance(exp_avg_sq, torch.Tensor)
if amsgrad:
assert isinstance(max_exp_avg_sq, torch.Tensor)

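            # Fused single-parameter AdamW update; DTensors are unwrapped to their
            # local shards, and max_exp_avg_sq is passed only when amsgrad is on.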
torch._fused_adamw_(
[self._to_local_tensor(param)],
[self._to_local_tensor(param.grad)],
[self._to_local_tensor(exp_avg)],
[self._to_local_tensor(exp_avg_sq)],
[self._to_local_tensor(max_exp_avg_sq)] if amsgrad else [],
[group["step"]],
amsgrad=amsgrad,
lr=group["lr"],
beta1=beta1,
beta2=beta2,
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
)

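        # Return the updated states to pinned host memory so device memory is freed
        # between steps; the synchronize ensures the async copies complete before
        # the caller mutates parameters or gradients.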
self._ensure_states_on_host()
DEVICE_MODULE.synchronize()
return loss
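
A minimal end-to-end sketch under the API above, assuming a CUDA device; the model, loop, and process-group settings are placeholders (swap_all_to_device calls dist.barrier internally, so a process group must be initialized even for a single rank):

import torch
import torch.distributed as dist

from xtuner.v1.optim import SwapAdamW

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(1024, 1024, device="cuda")  # placeholder model
optimizer = SwapAdamW(model.parameters(), lr=1e-4, swap_optimizer_times=16)

for _ in range(3):
    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()
    optimizer.step()  # swaps states in, runs torch._fused_adamw_, swaps them back out
    optimizer.zero_grad()

dist.destroy_process_group()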