12 changes: 9 additions & 3 deletions megatron/arguments.py
@@ -97,7 +97,9 @@ def validate_args(args, defaults={}):

if args.ds_sequence_parallel_size > 1:
assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+"

if args.ds_sequence_parallel_overlap_comm:
assert args.split_qkv_linear, \
"ds_sequence_parallel_overlap_comm requires split_qkv_linear is True"
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size * \
args.ds_sequence_parallel_size
@@ -924,6 +926,9 @@ def _add_training_args(parser):
group.add_argument('--disable-moe-top2-2nd-expert-sampling', action='store_false',
help='Disable MoE top2 sampling of the 2nd expert. Instead of sampling, use argmax.',
dest='moe_top2_2nd_expert_sampling')
group.add_argument('--split-qkv-linear', action='store_true',
help='Use separate linear layers for the query, key, and value projections.',
dest='split_qkv_linear')
group.add_argument('--use-flash-attn', '--use-flash-attn-v1', dest='use_flash_attn_v1', action='store_true',
help='use first version FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
@@ -975,14 +980,15 @@ def _add_training_args(parser):
help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.')
group.add_argument('--force-ds-sequence-parallel', action='store_true',
help='use DeepSpeed sequence parallelism regardless of sequence parallel size.')

group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
help='Overlap communication with computation for DeepSpeed sequence parallelism.',
dest='ds_sequence_parallel_overlap_comm')
group.add_argument('--ds-sequence-parallel-fpdt', action='store_true',
help='use DeepSpeed sequence parallelism with FPDT.')
group.add_argument('--ds-sequence-parallel-fpdt-chunk-size', type=int, default=65536,
help='Chunk size used in FPDT attention.')
group.add_argument('--ds-sequence-parallel-fpdt-offloading', action='store_true',
help='use DeepSpeed sequence parallelism FPDT with offloading.')

group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
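The new --ds-sequence-parallel-overlap-comm flag is only meaningful together with --split-qkv-linear, and validate_args() now enforces that. A minimal, self-contained sketch of the interaction follows; the argparse wiring mirrors the diff, while the hard-coded command line and the standalone-script form are illustrative assumptions rather than part of this change.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--split-qkv-linear', action='store_true',
                    help='Use separate linear layers for the query, key, and value projections.',
                    dest='split_qkv_linear')
parser.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
                    help='Overlap communication with computation for DeepSpeed sequence parallelism.',
                    dest='ds_sequence_parallel_overlap_comm')

# Hypothetical command line with both flags on: passes the same check validate_args() performs.
args = parser.parse_args(['--ds-sequence-parallel-overlap-comm', '--split-qkv-linear'])
if args.ds_sequence_parallel_overlap_comm:
    assert args.split_qkv_linear, \
        "ds_sequence_parallel_overlap_comm requires split_qkv_linear to be enabled"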
34 changes: 24 additions & 10 deletions megatron/core/tensor_parallel/layers.py
@@ -20,6 +20,7 @@
from megatron import get_args

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore

from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
@@ -248,13 +249,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
async_grad_allreduce, sequence_parallel, bwd_stream=None):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel

ctx.bwd_stream = bwd_stream

if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
@@ -314,6 +316,7 @@ def backward(ctx, grad_output):
total_input = all_gather_buffer
else:
total_input = input

grad_input = grad_output.matmul(weight)

if ctx.sequence_parallel:
@@ -368,23 +371,30 @@ def backward(ctx, grad_output):
# grad_weight = None
# else:
# grad_weight = grad_output.t().matmul(total_input)
if args.enable_zbh1_pipeline:
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore

if ctx.bwd_stream is not None:
# Sequence-parallel comm overlap: queue the weight-gradient update on the side stream.
ctx.bwd_stream.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(ctx.bwd_stream):
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None
elif args.enable_zbh1_pipeline:
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.bwd_stream is not None:
# Make sure total_input/grad_output memory is not reused before the side stream finishes with them.
total_input.record_stream(ctx.bwd_stream)
grad_output.record_stream(ctx.bwd_stream)
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None, None  # one gradient per forward input, including bwd_stream

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None

def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
@@ -393,6 +403,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
async_sp_all2all_stream=None
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
@@ -453,6 +464,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel,
async_sp_all2all_stream
]

if not linear_with_grad_accumulation_and_async_allreduce.warned:
@@ -607,7 +619,6 @@ def __init__(self, input_size, output_size, *,
"cannot be enabled at the same time."
)


def forward(self,
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None):
@@ -706,9 +717,10 @@ def __init__(self, input_size: int, output_size: int, *,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
moe=False, enable_expert_tensor_parallelism=False):
moe=False, enable_expert_tensor_parallelism=False, ds_sp_async_stream=None):
torch.nn.Module.__init__(self)

self.ds_sp_async_stream = ds_sp_async_stream

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
@@ -784,13 +796,15 @@ def forward(self, input_):
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.

output_parallel = linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel=False,
async_sp_all2all_stream=self.ds_sp_async_stream
)

# All-reduce across all the partitions.
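For orientation, the backward() changes above follow a standard side-stream pattern: wait for the producing stream, launch the weight-gradient work on a secondary stream, and call record_stream() so the caching allocator keeps the inputs alive. A minimal sketch of that pattern, assuming plain torch.cuda in place of get_accelerator(), computing the GEMM inline instead of deferring it through WeightGradStore, and made-up shapes:

import torch

def weight_grad_on_side_stream(total_input, grad_output, side_stream):
    # The side stream must not start reading tensors before the producing
    # (current) stream has finished writing them.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        grad_weight = grad_output.t().matmul(total_input)
    # Tell the caching allocator these tensors are still in use on the side
    # stream, so their memory is not reused before the GEMM has run.
    total_input.record_stream(side_stream)
    grad_output.record_stream(side_stream)
    return grad_weight

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    total_input = torch.randn(128, 64, device='cuda')   # [tokens, in_features]
    grad_output = torch.randn(128, 32, device='cuda')   # [tokens, out_features]
    grad_weight = weight_grad_on_side_stream(total_input, grad_output, side)
    torch.cuda.synchronize()
    print(grad_weight.shape)  # torch.Size([32, 64])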
106 changes: 79 additions & 27 deletions megatron/model/transformer.py
@@ -510,7 +510,14 @@ class ParallelAttention(MegatronModule):
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""

sp_stream = None

def get_sp_stream(self):
if not self.ds_sp_overlap:
return None
if ParallelAttention.sp_stream is None:
ParallelAttention.sp_stream = get_accelerator().Stream()
return ParallelAttention.sp_stream

def __init__(self, config, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
@@ -524,7 +531,8 @@ def __init__(self, config, layer_number,
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)

self.split_qkv = args.split_qkv_linear
self.ds_sp_overlap = args.ds_sequence_parallel_overlap_comm
self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
args.use_flash_attn_builder) \
and attention_type == AttnType.self_attn \
@@ -577,13 +585,31 @@ def __init__(self, config, layer_number,

# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False)
if not self.split_qkv:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False)

else:
linear_configs = [
("query_linear", projection_size),
("key_linear", kv_projection_size),
("value_linear", kv_projection_size),
]

for attr_name, output_size in linear_configs:
setattr(self, attr_name, tensor_parallel.ColumnParallelLinear(
config.hidden_size,
output_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False
))
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
@@ -614,12 +640,14 @@ def __init__(self, config, layer_number,
self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
or args.force_ds_sequence_parallel
if self.enable_ds_sequence_parallel:

assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0

self.dist_attn = DistributedAttention(
local_attn,
parallel_state.get_sequence_parallel_group(),
gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0)
gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0, sp_stream=self.get_sp_stream())
# flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension.
else:
if self.use_flash_attn:
@@ -636,7 +664,9 @@ def __init__(self, config, layer_number,
init_method=config.output_layer_init_method,
bias=args.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True)
skip_bias_add=True,
ds_sp_async_stream=self.get_sp_stream()
)


def _checkpointed_attention_forward(self, query_layer, key_layer,
@@ -722,22 +752,41 @@ def forward(self, hidden_states, attention_mask,
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

if self.enable_ds_sequence_parallel:
assert self.projection_size == self.kv_projection_size
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
if self.sequence_parallel or not self.enable_ds_sequence_parallel:
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
each_hidden_size = mixed_x_layer.shape[-1] // 3
query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)

if not self.split_qkv:
# Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

if self.enable_ds_sequence_parallel:
assert self.projection_size == self.kv_projection_size
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
if self.sequence_parallel or not self.enable_ds_sequence_parallel:
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
each_hidden_size = mixed_x_layer.shape[-1] // 3
query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)
else:
assert self.ds_sp_overlap, """
Currently, the split_qkv operation is only applicable
when ds_sp_overlap is enabled.
"""
self.get_sp_stream().wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(self.get_sp_stream()):
query_layer, _ = self.query_linear(hidden_states)
query_layer = query_layer.reshape(query_layer.shape[0], query_layer.shape[1], self.num_attention_heads, -1)
fwd_query_layer_done_event = get_accelerator().Event()
fwd_query_layer_done_event.record(self.get_sp_stream())

key_layer, _ = self.key_linear(hidden_states)
key_layer = key_layer.reshape(key_layer.shape[0], key_layer.shape[1], self.num_attention_heads, -1)
fwd_key_layer_done_event = get_accelerator().Event()
fwd_key_layer_done_event.record(self.get_sp_stream())

value_layer, _ = self.value_linear(hidden_states)
value_layer = value_layer.reshape(value_layer.shape[0], value_layer.shape[1], self.num_attention_heads, -1)

# Repeat kv
if self.use_gqa:
key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
@@ -833,6 +882,9 @@ def forward(self, hidden_states, attention_mask,
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

if self.enable_ds_sequence_parallel:
if self.ds_sp_overlap:
key_layer.done_event = fwd_key_layer_done_event
query_layer.done_event = fwd_query_layer_done_event
batch_dim_idx = 1
if self.use_flash_attn:
if not self.use_flash_attn_triton:
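The per-projection events added to forward() use the usual CUDA event handshake: record an event on the producer stream right after a projection finishes, then let the consumer (the all-to-all inside DistributedAttention) wait only on the event for the tensor it needs, instead of waiting for all three projections. A small sketch of that handshake, again assuming plain torch.cuda rather than get_accelerator() and made-up sizes:

import torch

if torch.cuda.is_available():
    hidden = torch.randn(16, 2, 512, device='cuda')            # [sq, b, h]
    q_proj = torch.nn.Linear(512, 512, bias=False).cuda()
    k_proj = torch.nn.Linear(512, 512, bias=False).cuda()
    v_proj = torch.nn.Linear(512, 512, bias=False).cuda()

    sp_stream = torch.cuda.Stream()
    sp_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(sp_stream):
        q = q_proj(hidden)
        q_done = torch.cuda.Event()
        q_done.record(sp_stream)     # q's consumer can start as soon as this fires
        k = k_proj(hidden)
        k_done = torch.cuda.Event()
        k_done.record(sp_stream)
        v = v_proj(hidden)

    # A consumer on another stream waits only for the tensor it needs:
    torch.cuda.current_stream().wait_event(q_done)
    torch.cuda.synchronize()
    print(q.shape, k.shape, v.shape)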