Skip to content

Commit 317e376

Browse files
authored
* reworked ddp func to avoid dead locks in tests
* reworked ddp func to avoid dead locks in tests
1 parent 8300607 commit 317e376

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

tests/test_oml/test_ddp/test_loader_patcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .utils import run_in_ddp
1818

1919

20-
@pytest.mark.skip(reason="Dead locks may appear when running in CI")
2120
@pytest.mark.long
2221
@pytest.mark.parametrize("n_labels_sampler", [2, 5])
2322
@pytest.mark.parametrize("n_instances_sampler", [2, 5])
@@ -156,7 +155,6 @@ def check_patching_balance_batch_sampler(
156155
assert len(set(outputs_from_epochs)) == len(outputs_from_epochs)
157156

158157

159-
@pytest.mark.skip(reason="Dead locks may appear when running in CI")
160158
@pytest.mark.long
161159
@pytest.mark.parametrize("shuffle", [True, False])
162160
@pytest.mark.parametrize("drop_last", [True, False])

tests/test_oml/test_ddp/utils.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import inspect
2-
import pickle
3-
import sys
2+
import socket
43
from datetime import timedelta
5-
from pathlib import Path
64
from typing import Any, Callable, Tuple
75

86
import torch
97
from torch.distributed import destroy_process_group, init_process_group
108
from torch.multiprocessing import spawn
119

12-
from oml.const import TMP_PATH
13-
from oml.utils.io import calc_hash
1410
from oml.utils.misc import set_global_seed
1511

1612

@@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore
2319
raise ValueError(f"The function '{fn.__name__}' should have 'rank' and 'world_size' as the first two arguments")
2420

2521

26-
def generate_connection_filename(world_size: int, fn: Callable, *args: Tuple[Any, ...]) -> Path: # type: ignore
27-
python_info = sys.version_info
28-
python_str = f"{python_info.major}.{python_info.minor}.{python_info.micro}"
29-
args_hash = calc_hash(pickle.dumps(args))
30-
filename = TMP_PATH / f"{fn.__name__}_{world_size}_{python_str}_{args_hash}"
31-
return filename
22+
def get_free_port() -> int:
23+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
24+
s.bind(("", 0))
25+
return s.getsockname()[1]
3226

3327

3428
def fn_ddp_wrapper(
35-
rank: int, connection_file: Path, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
29+
rank: int, port: int, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
3630
) -> Any: # type: ignore
3731
init_process_group(
3832
backend="gloo",
3933
rank=rank,
4034
world_size=world_size,
41-
init_method=f"file://{connection_file}",
35+
init_method=f"tcp://127.0.0.1:{port}",
4236
timeout=timedelta(seconds=120),
4337
)
4438
set_global_seed(1)
@@ -50,12 +44,12 @@ def fn_ddp_wrapper(
5044

5145
def run_in_ddp(world_size: int, fn: Callable, args: Tuple[Any, ...] = ()) -> Any: # type: ignore
5246
assert_signature(fn)
47+
set_global_seed(1)
48+
torch.set_num_threads(1)
5349
if world_size > 1:
54-
connection_file = generate_connection_filename(world_size, fn, *args)
55-
connection_file.unlink(missing_ok=True)
56-
connection_file.parent.mkdir(exist_ok=True, parents=True)
50+
port = get_free_port()
5751
# note, 'spawn' automatically passes 'rank' to its first argument
58-
spawn(fn_ddp_wrapper, args=(connection_file, world_size, fn, *args), nprocs=world_size, join=True)
52+
spawn(fn_ddp_wrapper, args=(port, world_size, fn, *args), nprocs=world_size, join=True)
5953
else:
6054
set_global_seed(1)
6155
return fn(0, world_size, *args)

0 commit comments

Comments
 (0)