Skip to content

Support xccl distributed backend #3034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion server/text_generation_server/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger
from packaging import version
from text_generation_server.utils.import_utils import SYSTEM

# Tensor Parallelism settings
Expand Down Expand Up @@ -45,6 +46,12 @@ def rank(self):
return self._rank


def _is_xccl_available():
if version.parse(torch.__version__).release >= version.parse("2.7").release:
return torch.distributed.distributed_c10d.is_xccl_available()
return False


def initialize_torch_distributed():
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
Expand All @@ -54,11 +61,20 @@ def initialize_torch_distributed():
device = RANK % torch.cuda.device_count()
torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
device = "cuda"
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=120)
elif SYSTEM == "xpu" and _is_xccl_available():
assert WORLD_SIZE <= torch.xpu.device_count(), "Each process is one gpu"
device = RANK % torch.xpu.device_count()
torch.xpu.set_device(device)
device = "xpu"
backend = "xccl"
options = None
else:
device = None
backend = "gloo"
options = None

Expand All @@ -81,7 +97,8 @@ def initialize_torch_distributed():
pg_options=options,
)
else:
device = torch.device(f"cuda:{RANK}")
if device:
device = torch.device(f"{device}:{RANK}")
torch.distributed.init_process_group(
backend=backend,
world_size=WORLD_SIZE,
Expand All @@ -90,6 +107,7 @@ def initialize_torch_distributed():
pg_options=options,
device_id=device,
)
logger.info(f"torch.distributed initialized with {backend} backend for rank {RANK}")
else:
logger.warning("torch.distributed is already initialized.")

Expand Down