
Commit 8f021e1

Fix nccl future execution (#126)
1 parent 06bab30 commit 8f021e1


2 files changed: +128 -13 lines


torchft/manager_integ_test.py

+102 -7
@@ -2,24 +2,26 @@
 import logging
 import threading
 import time
+import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
+from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
 from unittest import TestCase

 import torch
 import torch.distributed as dist
 from parameterized import parameterized
 from torch import nn, optim
+from torch._dynamo.utils import timed

 from torchft._torchft import LighthouseServer
 from torchft.ddp import DistributedDataParallel
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupGloo
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

 logger: logging.Logger = logging.getLogger(__name__)

@@ -69,10 +71,14 @@ def check(self, rank: int, step: int) -> None:
                 raise InjectedFailure(f"injected failure {rank=} {step=}")


-class TrainLoop(Protocol):
+# R for an arbitrary return type
+R = TypeVar("R", covariant=True)
+
+
+class TrainLoop(Protocol[R]):
     def __call__(
         self, rank: int, store_port: int, device: torch.device, runner: "Runner"
-    ) -> Dict[str, Dict[str, object]]: ...
+    ) -> R: ...


 @dataclass
@@ -81,15 +87,15 @@ class Runner:
     num_replicas: int
     lighthouse_address: str
     failure_injector: FailureInjector
-    train_loop: TrainLoop
+    train_loop: TrainLoop[object]

     use_cuda: bool = False
     world_size: int = 1
     attempts: int = 3
     manager_args: Dict[str, object] = field(default_factory=dict)
     train_loop_args: Dict[str, Any] = field(default_factory=dict)

-    def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
+    def _replica_main(self) -> List[object]:
         store = dist.TCPStore(
             host_name="localhost",
             port=0,
@@ -131,7 +137,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:

         return [fut.result() for fut in futures]

-    def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
+    def run_replica(self) -> List[object]:
         for i in range(self.attempts):
             try:
                 print(
@@ -391,3 +397,92 @@ def test_quorum_timeout(self) -> None:
             "status: Cancelled, message.*Timeout expired",
         ):
             manager.should_commit(timeout=timedelta(seconds=0.01))
+
+    @parameterized.expand(
+        [
+            (True,),  # Test with CUDA
+            (False,),  # Test without CUDA (CPU)
+        ]
+    )
+    def test_manager_allreduce(self, use_cuda: bool) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        # manager supports allreduce but we found an issue where the future callback is getting called
+        # before the allreduce is complete. This test is to ensure that the callback has stream synchronization
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=all_reduce_callback,
+                    use_cuda=use_cuda,
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            results = []
+            for fut in as_completed(futures):
+                try:
+                    results.append(fut.result()[0])
+                except Exception as e:
+                    print(e, flush=True)
+                    traceback.print_exc()
+                    raise
+
+        lighthouse.shutdown()
+
+        print(results)
+        r0, r1 = results
+        torch.testing.assert_close(r0, r1, check_device=False)
+
+
+def all_reduce_callback(
+    rank: int,
+    store_port: int,
+    device: torch.device,
+    runner: Runner,
+) -> Optional[torch.Tensor]:
+    with ExitStack() as stack:
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        if device.type == "cuda":
+            pg = ProcessGroupBabyNCCL()
+        else:
+            pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=lambda x: None,
+            state_dict=lambda: None,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
+            port=19530 + runner.replica_id,
+            timeout=timedelta(seconds=10),
+            quorum_timeout=timedelta(seconds=10),
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
+        stack.callback(lambda: manager.shutdown(wait=False))

+        manager.start_quorum()
+        t1 = torch.ones((1, 3), device=device)
+        fut = manager.allreduce(t1)
+        fut.wait()
+        return t1
+    return None
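Aside on the typing change above: TrainLoop's fixed return type (Dict[str, Dict[str, object]]) is replaced by a covariant TypeVar so that a train loop such as all_reduce_callback can return a tensor rather than a state-dict mapping. A minimal, standalone sketch of that generic-Protocol pattern; the names Loop, double, and run_loop are illustrative and not part of torchft:

from typing import Protocol, TypeVar

# Covariant because R only appears in return position.
R = TypeVar("R", covariant=True)

class Loop(Protocol[R]):
    # A callable protocol whose return type is generic.
    def __call__(self, rank: int) -> R: ...

def double(rank: int) -> int:
    return rank * 2

def run_loop(loop: Loop[object]) -> object:
    # Covariance lets a Loop[int] be passed where Loop[object] is expected,
    # which is how Runner can hold train_loop: TrainLoop[object].
    return loop(3)

print(run_loop(double))  # 6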

torchft/process_group.py

+26 -6
@@ -1093,10 +1093,12 @@ def _worker(

                         args = _PickleSafeOptions.unsafe_args(args)
                         fn = getattr(pg, func_name)
+
                         work[op_id] = _OpMetadata(
                             work=fn(*args, **kwargs),
                             stream=stream,
                         )
+
                 elif cmd == "wait":
                     op_id, timeout = cast(tuple[int, timedelta], op[1:])

@@ -1126,15 +1128,29 @@ def _worker(
                     del work[op_id]
                 elif cmd == "future":
                     op_id: int = cast(int, op[1])
+                    metadata: _OpMetadata = work[op_id]

-                    def callback(fut: Future[object]) -> None:
+                    def callback(fut: Future[object], metadata: _OpMetadata) -> None:
                         try:
-                            fut.wait()
-                            future_pipe.send((op_id, _FUTURE_RESULT, None))
+                            # create an event after the collective has been issued
+                            # to wait on this before we call "future"
+                            with metadata.set_stream():
+                                fut.wait()
+                                event = (
+                                    torch.cuda.current_stream().record_event(
+                                        torch.cuda.Event(interprocess=True)
+                                    )
+                                    if metadata.stream is not None
+                                    else None
+                                )
+
+                            future_pipe.send((op_id, _FUTURE_RESULT, None, event))
                         except Exception as e:
-                            future_pipe.send((op_id, _FUTURE_EXCEPTION, e))
+                            future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))

-                    work[op_id].work.get_future().add_done_callback(callback)
+                    metadata.work.get_future().add_done_callback(
+                        lambda fut: callback(fut, metadata)
+                    )
                 elif cmd == "num_active_work":
                     req_pipe.send(len(work))
                 else:
@@ -1153,11 +1169,15 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
                 except TimeoutError:
                     continue

-                op_id, mode, data = cast(Tuple[int, str, object], cmd)
+                op_id, mode, data, event = cast(
+                    Tuple[int, str, object, Optional[torch.cuda.Event]], cmd
+                )
                 with self._futures_lock:
                     fut = self._futures[op_id]
                     del self._futures[op_id]
                 if mode == _FUTURE_RESULT:
+                    if event is not None:
+                        event.wait()
                     fut.set_result(data)
                 elif mode == _FUTURE_EXCEPTION:
                     fut.set_exception(data)
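Context for the process_group.py change above: the worker's "future" callback now records a CUDA event on the stream that issued the collective and sends it over future_pipe, and _future_handler waits on that event before calling fut.set_result, so the manager-side future can no longer fire before the NCCL allreduce has finished on the GPU. A rough, self-contained sketch of that record-then-wait pattern (single process, illustrative names, not torchft code):

import torch

def issue_work(t: torch.Tensor, stream: torch.cuda.Stream) -> torch.cuda.Event:
    # Producer: enqueue the kernel on its own stream, then record an event
    # after it so the event completes only once the kernel has run.
    with torch.cuda.stream(stream):
        t.add_(1)  # stand-in for the async NCCL collective
        return stream.record_event(torch.cuda.Event(interprocess=True))

def consume(t: torch.Tensor, event: torch.cuda.Event) -> torch.Tensor:
    # Consumer: make the current stream wait on the event before any later
    # work touches the tensor. This orders GPU work without blocking the CPU,
    # mirroring event.wait() in _future_handler above.
    event.wait()
    return t

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    x = torch.zeros(4, device="cuda")
    ev = issue_work(x, s)
    print(consume(x, ev).cpu())  # tensor([1., 1., 1., 1.])

In the actual fix the event additionally crosses a process boundary (hence interprocess=True) via future_pipe, but the ordering guarantee is the same.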
