
Commit e3f5001

feat(comm/attn_offload.py): support selective ckpt and cpu offload (#383)
1 parent 141e9eb commit e3f5001

File tree

8 files changed: +549 −11 lines changed

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
from .attn_offload import get_offload_manager, initialize_offload_manager

__all__ = ["initialize_offload_manager", "get_offload_manager"]
Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
import torch

from internlm.utils.common import get_current_device

global_attn_offload = None


class AttnOffloadManager:
    """
    A manager for attention output CPU offloading and GPU prefetch loading.
    """

    def __init__(self, enable_cpu_offload: bool = False) -> None:
        # cpu offload overlapping
        self.cpu_offload = enable_cpu_offload
        # layer id mapping to flash attn output
        self.fa_output_mapping = {}
        self.fa_stream = torch.cuda.Stream()
        self.d2h_final_event = torch.cuda.Event()
        self.h2d_final_event = torch.cuda.Event()
        # prepare for tensor buffer
        self.tensor_id_to_tensor_bufs = {}

    def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
        """Get tensor buffer for offloaded tensor."""
        layer_id = layer_id % 2
        if layer_id not in self.tensor_id_to_tensor_bufs:
            self.tensor_id_to_tensor_bufs[layer_id] = {}

        if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
            allocate_new_buf = True
        else:
            # reuse the cached buffer only if its shape and dtype still match
            tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
            allocate_new_buf = not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype)

        if allocate_new_buf:
            # supposed to only execute once
            buffer = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=tensor.device,
            )

            self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer

        return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]

    def insert_fa_output_with_layer(self, layer_idx, output):
        assert layer_idx not in self.fa_output_mapping
        if self.cpu_offload is False:
            self.fa_output_mapping[layer_idx] = output
            return

        tensors = []
        for tensor_id, tensor in enumerate(output):
            if tensor is None:
                tensors.append(None)
                continue
            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
            tensor_buf.copy_(tensor)
            tensors.append(tensor_buf)
        self.fa_output_mapping[layer_idx] = tensors

    def get_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping
        return self.fa_output_mapping.pop(layer_idx)

    def offload_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping

        self.fa_stream.wait_stream(torch.cuda.current_stream())
        self.fa_stream.wait_event(self.d2h_final_event)

        with torch.cuda.stream(self.fa_stream):
            _gpu_tensors = self.fa_output_mapping.pop(layer_idx)
            _cpu_tensors = []
            for _tensor in _gpu_tensors:
                if _tensor is None:
                    _cpu_tensors.append(_tensor)
                    continue

                _cpu_backup = torch.empty(
                    _tensor.size(),
                    dtype=_tensor.dtype,
                    layout=_tensor.layout,
                    device="cpu",
                    pin_memory=True,
                )
                _cpu_backup.copy_(_tensor, non_blocking=True)
                _cpu_tensors.append(_cpu_backup)

                # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))

            self.fa_output_mapping[layer_idx] = _cpu_tensors

        self.fa_stream.record_event(self.d2h_final_event)

    def preload_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping

        self.fa_stream.wait_stream(torch.cuda.current_stream())
        self.fa_stream.wait_event(self.h2d_final_event)

        # Important: get the device before entering the stream context;
        # calling get_current_device() inside the stream context is an error.
        _device = get_current_device()
        with torch.cuda.stream(self.fa_stream):
            _cpu_tensors = self.fa_output_mapping.pop(layer_idx)
            self.fa_output_mapping[layer_idx] = [
                _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
                for _tensor in _cpu_tensors
            ]

        self.fa_stream.record_event(self.h2d_final_event)


def initialize_offload_manager(enable_cpu_offload: bool = False):
    global global_attn_offload
    if global_attn_offload is None:
        global_attn_offload = AttnOffloadManager(enable_cpu_offload)

    return global_attn_offload


def get_offload_manager():
    assert global_attn_offload is not None
    return global_attn_offload
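
For orientation, here is a minimal sketch of how the new manager's API is meant to be driven. Only the import path and the method names come from this commit; the layer count, tensor shapes, bare loops, and explicit event waits are illustrative assumptions (in training, the hooks in isp.py below drive these calls), and a CUDA device is required since the manager allocates a CUDA stream.

import torch

from internlm.core.parallel.comm import get_offload_manager, initialize_offload_manager

# create the process-wide singleton with CPU offload enabled
manager = initialize_offload_manager(enable_cpu_offload=True)
assert manager is get_offload_manager()

num_layers = 4  # illustrative

# forward pass: stash each block's attention output, then launch the async D2H copy
for layer_idx in range(num_layers):
    attn_output = (torch.randn(2, 1024, 64, device="cuda"), None)  # placeholder output tuple
    manager.insert_fa_output_with_layer(layer_idx, attn_output)
    manager.offload_fa_output_with_layer(layer_idx)
    # staging buffers are shared between even/odd layers, so order the D2H copy
    # before the next layer's insert reuses its buffer
    torch.cuda.current_stream().wait_event(manager.d2h_final_event)

# backward pass: prefetch each block's output back to the GPU, then consume it once
for layer_idx in reversed(range(num_layers)):
    manager.preload_fa_output_with_layer(layer_idx)
    torch.cuda.current_stream().wait_event(manager.h2d_final_event)
    saved_output = manager.get_fa_output_with_layer(layer_idx)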

internlm/core/parallel/comm/isp.py

Lines changed: 34 additions & 0 deletions

@@ -37,6 +37,8 @@
     params_dispatch_with_condition,
 )
 
+from .attn_offload import get_offload_manager
+
 
 # not really useful, only for code hint.
 class WPCommunicator(ABC):

@@ -306,6 +308,7 @@ def __init__(
         overlap: bool = False,
         process_group: dist.ProcessGroup = None,
         is_moe: bool = False,
+        selective_ckpt_offload: bool = False,
     ) -> None:
         self.process_group = process_group
         self.overlap = overlap

@@ -316,6 +319,14 @@ def __init__(
         self._forward_prefetch_prerequisites = []
         self._forward_overlap_per = self._get_forward_overlap_granularity()
         self._launch_before_module = self._get_launch_before_module()
+        # As an optimization, do not release the weight after forward for the last
+        # transformer block, since wp would prefetch it immediately
+        self.layers_wp_not_release = []  # [gpc.config.isp_num_layers - 1]
+        self.layers_fa_not_release = [
+            gpc.config.isp_num_layers - 1,
+            int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
+        ]
+        self.sc_offload = selective_ckpt_offload
 
         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}

@@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module):
                 self._overlap_states[cid].index_to_isp_modules[idx].append(child)
 
                 setattr(child, "isp_name", name)
+                setattr(child, "isp_layer_idx", idx)
 
                 full_name = f"{cid}.{idx}.{name}"
                 setattr(

@@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
         if block_index + 1 < self._num_blocks:
             self._all_gather_block_weight(block_index + 1)
 
+        # register offload and prefetch hook for selective ckpt with wo linear
+        if self.sc_offload is True:
+            # move the current layer's attn output from GPU to CPU asynchronously
+            if (
+                self.is_forward is True
+                and gpc.config.selective_checkpoint
+                and block_index not in self.layers_fa_not_release
+                and block_index < self._ckpt_block_num
+            ):
+                get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)
+
+            # load the previous layer's attn output from CPU to GPU asynchronously
+            if (
+                self.is_forward is False
+                and gpc.config.selective_checkpoint
+                and (0 <= (block_index - 1) < self._ckpt_block_num)
+            ):
+                get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)
+
     def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         if module not in self._weight_global_handle:
             self._all_gather_module_weight(module)

@@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: dis
             self._all_gather_module_weight(next_module)
 
     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
+        if int(module.isp_layer_idx) in self.layers_wp_not_release:
+            # keep this layer's weight resident after forward instead of releasing it
+            return
         if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_handle(module)
             self._clear_weight(module)
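
Taken together, the pre-forward hook of each block's wo linear drives both directions of traffic: on the forward pass it launches the D2H copy of the current block's attention output (skipping the last block and the last checkpointed block), and on the backward pass it prefetches the previous checkpointed block's output one block ahead of its recomputation. A minimal hedged sketch of the consumer side, assuming the recomputation path reads the prefetched tensors through the manager (that call site is outside this diff):

# hedged sketch (not part of this commit): before the recomputation of a
# checkpointed block reads its prefetched attention output, the main stream
# must wait for the side stream's H2D copy to finish
mgr = get_offload_manager()
torch.cuda.current_stream().wait_event(mgr.h2d_final_event)
saved_attn_output = mgr.get_fa_output_with_layer(layer_idx)  # layer_idx: the block being recomputed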

internlm/core/trainer_builder.py

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,7 @@
 from internlm.checkpoint.checkpoint_manager import CheckpointManager
 from internlm.core.context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.core.parallel.comm import initialize_offload_manager
 from internlm.core.trainer import Trainer
 from internlm.data.streaming.utils import streaming_simple_resume
 from internlm.data.train_state import get_train_state

@@ -118,6 +119,9 @@ def __init__(
         # initialize isp communicator
         isp_communicator = initialize_parallel_communicator(model)
 
+        # initialize cpu offload manager for selective checkpoint
+        initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))
+
         # initialize train state
         train_state = get_train_state(train_dl)

internlm/initialize/launch.py

Lines changed: 30 additions & 7 deletions

@@ -66,13 +66,22 @@ def get_default_parser():
 def args_sanity_check():
     assert gpc.config is not None, "config is not load!"
 
+    gpc.is_forward = True
+
     if "JOB_NAME" not in gpc.config:
         gpc.config._add_item("JOB_NAME", "AnonymousJob")
 
     # the default model type is INTERNLM
     if "model_type" not in gpc.config:
         gpc.config._add_item("model_type", ModelType.INTERNLM.name)
 
+    if gpc.config.model_type == "InternLM3_M":
+        # TODO: need check for isp overlap
+        num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers
+    else:
+        num_layers = gpc.config.model.num_layers
+    gpc.config.isp_num_layers = num_layers
+
     if "use_apex_adam" not in gpc.config:
         gpc.config._add_item("use_apex_adam", False)
 

@@ -388,17 +397,18 @@ def args_sanity_check():
             gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
         if gpc.config.parallel["tensor"].get("mode", None) is None:
             gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
-        assert (
-            gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
-        ), "VOCAB_SIZE must be integer multiple of tensor parallel size"
         if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
             assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
             assert (
                 torch.__version__ >= "2.1.0"
             ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
-            assert (
-                gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
-            ), "VOCAB_SIZE must be integer multiple of wp size"
+
+        assert (
+            gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0
+        ), "model.vocab_size must be integer multiple of weight parallel size"
+        assert (
+            gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0
+        ), "model.vocab_size must be integer multiple of tensor parallel size"
 
         assert gpc.config.parallel["tensor"].get("mode", None) in [
             TensorParallelMode.mtp.name,

@@ -524,7 +534,20 @@ def args_sanity_check():
             gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
     if "selective_checkpoint" not in gpc.config:
-        gpc.config._add_item("selective_checkpoint", False)
+        gpc.config.selective_checkpoint = False
+    if "selective_checkpoint_offload" not in gpc.config:
+        gpc.config.selective_checkpoint_offload = False
+    if gpc.config.selective_checkpoint is True:
+        assert (
+            gpc.config.parallel["tensor"]["mode"] == "isp"
+        ), "When using selective_checkpoint, tensor parallel mode must be isp"
+    if gpc.config.selective_checkpoint_offload is True:
+        assert (
+            gpc.config.selective_checkpoint is True
+        ), "When using selective_checkpoint_offload, selective_checkpoint must be True"
+        assert (
+            gpc.config.parallel.weight.launch_allgather_before == "wo"
+        ), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module"
 
     # moe not support overlap and zero1.5 for now
     if gpc.config.model.get("num_experts", 1) > 1:
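
The sanity checks above imply a particular shape for the training config. Below is a hedged example fragment that satisfies them; only the keys asserted on in this diff (selective_checkpoint, selective_checkpoint_offload, parallel.tensor.mode, parallel.weight.launch_allgather_before, and the model.vocab_size divisibility) are prescribed, while all sizes and the remaining fields are illustrative.

# hedged config fragment; values are illustrative
model = dict(
    num_layers=32,
    vocab_size=92544,  # must be divisible by both tensor and weight parallel sizes
    checkpoint=0.5,    # fraction of blocks under activation checkpointing
)

selective_checkpoint = True              # required when selective_checkpoint_offload is on
selective_checkpoint_offload = True      # enables the attention-output CPU offload manager

parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, mode="isp"),     # selective_checkpoint requires isp tensor parallel mode
    weight=dict(size=4, overlap=True, launch_allgather_before="wo"),  # all-gather launched before the wo linear
)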
