1
1
import inspect
2
- import pickle
3
- import sys
2
+ import socket
4
3
from datetime import timedelta
5
- from pathlib import Path
6
4
from typing import Any , Callable , Tuple
7
5
8
6
import torch
9
7
from torch .distributed import destroy_process_group , init_process_group
10
8
from torch .multiprocessing import spawn
11
9
12
- from oml .const import TMP_PATH
13
- from oml .utils .io import calc_hash
14
10
from oml .utils .misc import set_global_seed
15
11
16
12
@@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore
23
19
raise ValueError (f"The function '{ fn .__name__ } ' should have 'rank' and 'world_size' as the first two arguments" )
24
20
25
21
26
- def generate_connection_filename (world_size : int , fn : Callable , * args : Tuple [Any , ...]) -> Path : # type: ignore
27
- python_info = sys .version_info
28
- python_str = f"{ python_info .major } .{ python_info .minor } .{ python_info .micro } "
29
- args_hash = calc_hash (pickle .dumps (args ))
30
- filename = TMP_PATH / f"{ fn .__name__ } _{ world_size } _{ python_str } _{ args_hash } "
31
- return filename
22
def get_free_port() -> int:
    """Return a TCP port number that was free at the moment of the call.

    A temporary IPv4 stream socket is bound to port ``0`` so that the operating
    system picks an unused port; the number is read back and the socket is
    closed when the ``with`` block exits.

    Returns:
        The port number chosen by the operating system.

    Note:
        The socket is already closed when the value is returned, so nothing
        reserves the port afterwards; another process may claim it before the
        caller binds to it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # An empty host means "all interfaces"; port 0 delegates the choice to the OS.
        s.bind(("", 0))
        # For AF_INET sockets getsockname() yields a (host, port) pair.
        return s.getsockname()[1]
32
26
33
27
34
28
def fn_ddp_wrapper (
35
- rank : int , connection_file : Path , world_size : int , fn : Callable , * args : Tuple [Any , ...] # type: ignore
29
+ rank : int , port : int , world_size : int , fn : Callable , * args : Tuple [Any , ...] # type: ignore
36
30
) -> Any : # type: ignore
37
31
init_process_group (
38
32
backend = "gloo" ,
39
33
rank = rank ,
40
34
world_size = world_size ,
41
- init_method = f"file ://{ connection_file } " ,
35
+ init_method = f"tcp ://127.0.0.1: { port } " ,
42
36
timeout = timedelta (seconds = 120 ),
43
37
)
44
38
set_global_seed (1 )
@@ -50,12 +44,12 @@ def fn_ddp_wrapper(
50
44
51
45
def run_in_ddp(world_size: int, fn: Callable, args: Tuple[Any, ...] = ()) -> Any:  # type: ignore
    """Run ``fn`` in ``world_size`` processes, or directly when ``world_size`` is 1 or less.

    Args:
        world_size: Number of processes to start. Values greater than 1 take
            the multi-process path; otherwise ``fn`` is called in this process.
        fn: Function to execute. It is checked with ``assert_signature`` before
            anything else runs and is called as ``fn(rank, world_size, *args)``
            on the single-process path.
        args: Extra positional arguments forwarded to ``fn`` after ``rank`` and
            ``world_size``.

    Returns:
        The result of ``fn(0, world_size, *args)`` on the single-process path,
        ``None`` on the multi-process path.

    Raises:
        ValueError: Raised by ``assert_signature`` when ``fn`` does not have the
            expected leading arguments.
    """
    assert_signature(fn)
    set_global_seed(1)
    # NOTE(review): this changes the thread setting of the calling process and is
    # not restored afterwards — confirm that is intended.
    torch.set_num_threads(1)

    if world_size > 1:
        # All workers rendezvous on the same port, so it is chosen once here.
        port = get_free_port()
        # note, 'spawn' automatically passes 'rank' to its first argument
        spawn(fn_ddp_wrapper, args=(port, world_size, fn, *args), nprocs=world_size, join=True)
        # Results of the workers are not collected; make the empty result explicit.
        return None

    # NOTE(review): the seed was already set above; this second call looks
    # redundant but is kept to preserve the original call sequence.
    set_global_seed(1)
    return fn(0, world_size, *args)
0 commit comments