support fp32 all_reduce and reduce_scatter #389

Merged · 2 commits · Feb 25, 2025
36 changes: 36 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -29,10 +29,34 @@ def wait(self) -> None:
DUMMY_HANDLE_CONST = DummyAsyncCommHandle()


class WrappedHandle:
    """
    Handle precision conversion for async all_reduce or reduce_scatter.
    """

    def __init__(self, handle, output, dtype):
        self.handle = handle
        self.output = output
        self.dtype = dtype

    def wait(self):
        self.handle.wait()
        if gpc.config.reduce_comm_dtype != self.dtype:
            self.output.data = self.output.to(self.dtype)
        self.output = None


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        input_ = input_.to(gpc.config.reduce_comm_dtype)
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        if async_op is False:
            input_ = input_.to(gpc.config.model.dtype)
        else:
            handle = WrappedHandle(handle=handle, output=input_, dtype=gpc.config.model.dtype)
    return input_, handle


@@ -122,7 +146,11 @@ def _reduce(input_, parallel_mode):
        return input_

    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        input_ = input_.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(input_, group=group)
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        input_ = input_.to(gpc.config.model.dtype)

    return input_

@@ -241,6 +269,9 @@ def reduce_scatter_raw(
    if world_size <= 1:
        return input_, None

    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        input_ = input_.to(gpc.config.reduce_comm_dtype)

    shape_list = list(input_.shape)
    shape_list[reduce_dim] = shape_list[reduce_dim] // world_size

@@ -251,6 +282,11 @@ def reduce_scatter_raw(
    ).contiguous()

    handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op)
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        if async_op is False:
            output = output.to(gpc.config.model.dtype)
        else:
            handle = WrappedHandle(handle=handle, output=output, dtype=gpc.config.model.dtype)
    return output, handle
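
Note (reviewer sketch, not part of the diff): a hypothetical call site for the new async path, assuming `gpc` and its process groups are already initialized with `reduce_comm_dtype=torch.float32` and a bf16 model; the tensor name and the choice of `ParallelMode.DATA` are illustrative only.

```python
# Hypothetical usage; `grad` and ParallelMode.DATA are illustrative, not from this PR.
grad, handle = all_reduce_raw(grad, process_group=gpc.get_group(ParallelMode.DATA), async_op=True)
# ... overlap independent compute with the fp32 all_reduce ...
handle.wait()  # WrappedHandle: waits on the collective, then casts grad back to gpc.config.model.dtype in place
```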


4 changes: 4 additions & 0 deletions internlm/core/scheduler/no_pipeline_scheduler.py
@@ -136,7 +136,11 @@ def _train_one_batch(
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.model.dtype)
    moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= scale_loss
loss /= scale_loss
21 changes: 21 additions & 0 deletions internlm/core/scheduler/pipeline_scheduler_1f1b.py
@@ -320,7 +320,11 @@ def _forward_step(
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.model.dtype)
    moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
@@ -458,7 +462,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
    accum_loss += accum_moe_loss
@@ -662,7 +670,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
    accum_loss += accum_moe_loss
@@ -883,7 +895,12 @@ def _forward_step(self, engine, chunk_id, input_obj=None):
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        moe_loss = moe_loss.to(gpc.config.model.dtype)

moe_loss /= self.num_microbatches

if self._accum_moe_loss is not None:
@@ -1414,7 +1431,11 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
    output, label = (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.model.dtype)
accum_moe_loss = self._accum_moe_loss

accum_loss = self._accum_loss
5 changes: 4 additions & 1 deletion internlm/core/scheduler/pipeline_scheduler_zb.py
@@ -351,7 +351,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
    dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
    if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
        accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
    accum_loss += accum_moe_loss
@@ -901,7 +905,6 @@ def _run_steady_loop(
else:
    next_unit_chunk_id = 1

# import pdb; pdb.set_trace()
if unit_step == num_units_stage1 - 1:
    chunk0_B_need_recv_prev_chunk0_output = False
else:
16 changes: 16 additions & 0 deletions internlm/initialize/launch.py
@@ -484,6 +484,22 @@ def args_sanity_check():
if "early_reduce_scatter_release" not in gpc.config.parallel.expert_weight:
gpc.config.parallel.expert_weight["early_reduce_scatter_release"] = True

# the comm_dtype for reduce communication
if gpc.config.get("reduce_comm_dtype", None) is None:
gpc.config.reduce_comm_dtype = gpc.config.model.dtype
else:
if gpc.config.reduce_comm_dtype == "torch.bfloat16":
gpc.config.reduce_comm_dtype = torch.bfloat16
elif gpc.config.reduce_comm_dtype == "torch.float32":
gpc.config.reduce_comm_dtype = torch.float32
else:
assert gpc.config.reduce_comm_dtype in [
"torch.bfloat16",
"torch.float32",
]
if gpc.config.model.dtype == torch.float32:
assert gpc.config.reduce_comm_dtype == gpc.config.model.dtype

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
assert (
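
Note (reviewer sketch, not part of the diff): with the check above, enabling fp32 reduction from a training config could look like the fragment below, assuming the usual Python-style config file that `args_sanity_check` reads; the field values are illustrative.

```python
# Hypothetical config fragment: the dtype strings match what args_sanity_check accepts.
model = dict(
    dtype="torch.bfloat16",  # compute/model dtype
    # ...
)
reduce_comm_dtype = "torch.float32"  # all_reduce / reduce_scatter run in fp32, results are cast back
```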
15 changes: 10 additions & 5 deletions internlm/solver/optimizer/utils.py
@@ -12,6 +12,7 @@

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import WrappedHandle
from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import (
@@ -106,13 +107,12 @@ def reduce_tensor(
    # use the original dtype
    # if dtype is None:
    assert dtype is None
    dtype = tensor.dtype
    dtype = gpc.config.reduce_comm_dtype
    tensor_dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    # if tensor.dtype != dtype:
    #     tensor_to_reduce = tensor.to(dtype)
    # else:
    #     tensor_to_reduce = tensor
    if tensor_dtype != dtype:
        tensor = tensor.to(dtype)

    # world_size = gpc.get_world_size(parallel_mode)
    # tensor.div_(world_size)
@@ -129,6 +129,11 @@ def reduce_tensor(
        global_rank = ranks_in_group[dst_rank]
        handle = dist.reduce(tensor=tensor, dst=global_rank, group=group, op=op_type, async_op=async_op)

    if tensor_dtype != dtype:
        if async_op:
            handle = WrappedHandle(handle=handle, output=tensor, dtype=tensor_dtype)
        else:
            tensor = tensor.to(tensor_dtype)
    return handle
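
Note (reviewer sketch, not from the PR): `reduce_tensor` follows the same cast → communicate → cast-back pattern as the other call sites. The minimal standalone example below mimics it with a single-process gloo group so it runs without a launcher; the names (`grad`, `buf`) and the single-rank setup are assumptions for illustration only, and in InternEvo the group and dtypes come from `gpc.config`.

```python
import os

import torch
import torch.distributed as dist

# Single-process gloo group purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model_dtype = torch.bfloat16   # stands in for gpc.config.model.dtype
comm_dtype = torch.float32     # stands in for gpc.config.reduce_comm_dtype

grad = torch.randn(8, dtype=model_dtype)
buf = grad.to(comm_dtype)                        # upcast before the collective
handle = dist.all_reduce(buf, async_op=True)     # reduction happens in fp32
handle.wait()                                    # WrappedHandle.wait() wraps this ...
grad.data = buf.to(model_dtype)                  # ... and this cast back to the model dtype

dist.destroy_process_group()
```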