import logging
import threading
import time
+ import traceback

from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
- from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
+ from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
from unittest import TestCase

import torch
import torch.distributed as dist
from parameterized import parameterized
from torch import nn, optim
+ from torch._dynamo.utils import timed

from torchft._torchft import LighthouseServer
from torchft.ddp import DistributedDataParallel
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
- from torchft.process_group import ProcessGroupGloo
+ from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

logger: logging.Logger = logging.getLogger(__name__)

@@ -69,10 +71,14 @@ def check(self, rank: int, step: int) -> None:
                raise InjectedFailure(f"injected failure {rank=} {step=}")


- class TrainLoop(Protocol):
+ # R for an arbitrary return type
+ R = TypeVar("R", covariant=True)
+
+
+ class TrainLoop(Protocol[R]):
    def __call__(
        self, rank: int, store_port: int, device: torch.device, runner: "Runner"
-     ) -> Dict[str, Dict[str, object]]: ...
+     ) -> R: ...


@dataclass
@@ -81,15 +87,15 @@ class Runner:
    num_replicas: int
    lighthouse_address: str
    failure_injector: FailureInjector
-     train_loop: TrainLoop
+     train_loop: TrainLoop[object]

    use_cuda: bool = False
    world_size: int = 1
    attempts: int = 3
    manager_args: Dict[str, object] = field(default_factory=dict)
    train_loop_args: Dict[str, Any] = field(default_factory=dict)

-     def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
+     def _replica_main(self) -> List[object]:
        store = dist.TCPStore(
            host_name="localhost",
            port=0,
@@ -131,7 +137,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:

        return [fut.result() for fut in futures]

-     def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
+     def run_replica(self) -> List[object]:
        for i in range(self.attempts):
            try:
                print(
@@ -391,3 +397,92 @@ def test_quorum_timeout(self) -> None:
            "status: Cancelled, message.*Timeout expired",
        ):
            manager.should_commit(timeout=timedelta(seconds=0.01))
+
+     @parameterized.expand(
+         [
+             (True,),  # Test with CUDA
+             (False,),  # Test without CUDA (CPU)
+         ]
+     )
+     def test_manager_allreduce(self, use_cuda: bool) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+
+         # The manager supports allreduce, but we found an issue where the future callback was invoked
+         # before the allreduce had completed. This test ensures the callback performs stream synchronization.
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 2
+         futures = []
+
+         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+             for replica_id in range(num_replicas):
+                 failure_injector = FailureInjector()
+                 runner = Runner(
+                     replica_id=replica_id,
+                     num_replicas=num_replicas,
+                     lighthouse_address=lighthouse.address(),
+                     failure_injector=failure_injector,
+                     train_loop=all_reduce_callback,
+                     use_cuda=use_cuda,
+                 )
+                 futures.append(executor.submit(runner.run_replica))
+
+             results = []
+             for fut in as_completed(futures):
+                 try:
+                     results.append(fut.result()[0])
+                 except Exception as e:
+                     print(e, flush=True)
+                     traceback.print_exc()
+                     raise
+
+         lighthouse.shutdown()
+
+         print(results)
+         r0, r1 = results
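+         # Every replica reduces an identical ones tensor, so the tensors returned by the two replicas must match.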
+         torch.testing.assert_close(r0, r1, check_device=False)
+
+
+ def all_reduce_callback(
+     rank: int,
+     store_port: int,
+     device: torch.device,
+     runner: Runner,
+ ) -> Optional[torch.Tensor]:
+     with ExitStack() as stack:
+         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
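+         # Pick a process group matching the device: ProcessGroupBabyNCCL for CUDA, ProcessGroupGloo for CPU.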
+         if device.type == "cuda":
+             pg = ProcessGroupBabyNCCL()
+         else:
+             pg = ProcessGroupGloo()
+         manager = Manager(
+             pg=pg,
+             min_replica_size=2,
+             use_async_quorum=False,
+             load_state_dict=lambda x: None,
+             state_dict=lambda: None,
+             replica_id=str(runner.replica_id),
+             store_addr="localhost",
+             store_port=store_port,
+             rank=rank,
+             world_size=runner.world_size,
+             lighthouse_addr=runner.lighthouse_address,
+             port=19530 + runner.replica_id,
+             timeout=timedelta(seconds=10),
+             quorum_timeout=timedelta(seconds=10),
+             # pyre-fixme[6]: Incompatible parameter type
+             **runner.manager_args,
+         )
+         stack.callback(lambda: manager.shutdown(wait=False))
+
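+         # Form a quorum so both replicas participate in the same allreduce round.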
+         manager.start_quorum()
+         t1 = torch.ones((1, 3), device=device)
+         fut = manager.allreduce(t1)
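+         # wait() must not return until the reduction has completed; on CUDA that means synchronizing with the communication stream before t1 is read on the host.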
+         fut.wait()
+         return t1
+     return None
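
The stream-synchronization concern this test exercises is general to asynchronous CUDA collectives: the call returns to the host while the reduction is still queued on the communication stream. A minimal sketch using plain `torch.distributed` (an illustration under that assumption, not torchft's `Manager` internals):

```python
# Sketch only: assumes an already-initialized torch.distributed process group (e.g. NCCL).
import torch
import torch.distributed as dist


def allreduce_and_read(t: torch.Tensor) -> torch.Tensor:
    work = dist.all_reduce(t, async_op=True)
    # For NCCL, wait() orders the caller's current CUDA stream after the
    # collective; it does not block the host thread.
    work.wait()
    if t.is_cuda:
        # Block the host before reading values (e.g. in an assertion),
        # otherwise the comparison could observe stale data.
        torch.cuda.synchronize()
    return t
```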