11import inspect
2- import pickle
3- import sys
2+ import socket
43from datetime import timedelta
5- from pathlib import Path
64from typing import Any , Callable , Tuple
75
86import torch
97from torch .distributed import destroy_process_group , init_process_group
108from torch .multiprocessing import spawn
119
12- from oml .const import TMP_PATH
13- from oml .utils .io import calc_hash
1410from oml .utils .misc import set_global_seed
1511
1612
@@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore
2319 raise ValueError (f"The function '{ fn .__name__ } ' should have 'rank' and 'world_size' as the first two arguments" )
2420
2521
26- def generate_connection_filename (world_size : int , fn : Callable , * args : Tuple [Any , ...]) -> Path : # type: ignore
27- python_info = sys .version_info
28- python_str = f"{ python_info .major } .{ python_info .minor } .{ python_info .micro } "
29- args_hash = calc_hash (pickle .dumps (args ))
30- filename = TMP_PATH / f"{ fn .__name__ } _{ world_size } _{ python_str } _{ args_hash } "
31- return filename
22+ def get_free_port () -> int :
23+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
24+ s .bind (("" , 0 ))
25+ return s .getsockname ()[1 ]
3226
3327
3428def fn_ddp_wrapper (
35- rank : int , connection_file : Path , world_size : int , fn : Callable , * args : Tuple [Any , ...] # type: ignore
29+ rank : int , port : int , world_size : int , fn : Callable , * args : Tuple [Any , ...] # type: ignore
3630) -> Any : # type: ignore
3731 init_process_group (
3832 backend = "gloo" ,
3933 rank = rank ,
4034 world_size = world_size ,
41- init_method = f"file ://{ connection_file } " ,
35+ init_method = f"tcp ://127.0.0.1: { port } " ,
4236 timeout = timedelta (seconds = 120 ),
4337 )
4438 set_global_seed (1 )
@@ -50,12 +44,12 @@ def fn_ddp_wrapper(
5044
5145def run_in_ddp (world_size : int , fn : Callable , args : Tuple [Any , ...] = ()) -> Any : # type: ignore
5246 assert_signature (fn )
47+ set_global_seed (1 )
48+ torch .set_num_threads (1 )
5349 if world_size > 1 :
54- connection_file = generate_connection_filename (world_size , fn , * args )
55- connection_file .unlink (missing_ok = True )
56- connection_file .parent .mkdir (exist_ok = True , parents = True )
50+ port = get_free_port ()
5751 # note, 'spawn' automatically passes 'rank' to its first argument
58- spawn (fn_ddp_wrapper , args = (connection_file , world_size , fn , * args ), nprocs = world_size , join = True )
52+ spawn (fn_ddp_wrapper , args = (port , world_size , fn , * args ), nprocs = world_size , join = True )
5953 else :
6054 set_global_seed (1 )
6155 return fn (0 , world_size , * args )
0 commit comments