diff --git a/tests/test_oml/test_ddp/test_loader_patcher.py b/tests/test_oml/test_ddp/test_loader_patcher.py index 8bc5d65b..9d8558c4 100644 --- a/tests/test_oml/test_ddp/test_loader_patcher.py +++ b/tests/test_oml/test_ddp/test_loader_patcher.py @@ -17,7 +17,6 @@ from .utils import run_in_ddp -@pytest.mark.skip(reason="Dead locks may appear when running in CI") @pytest.mark.long @pytest.mark.parametrize("n_labels_sampler", [2, 5]) @pytest.mark.parametrize("n_instances_sampler", [2, 5]) @@ -156,7 +155,6 @@ def check_patching_balance_batch_sampler( assert len(set(outputs_from_epochs)) == len(outputs_from_epochs) -@pytest.mark.skip(reason="Dead locks may appear when running in CI") @pytest.mark.long @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False]) diff --git a/tests/test_oml/test_ddp/utils.py b/tests/test_oml/test_ddp/utils.py index 53b9c728..8c10ec9a 100644 --- a/tests/test_oml/test_ddp/utils.py +++ b/tests/test_oml/test_ddp/utils.py @@ -1,16 +1,12 @@ import inspect -import pickle -import sys +import socket from datetime import timedelta -from pathlib import Path from typing import Any, Callable, Tuple import torch from torch.distributed import destroy_process_group, init_process_group from torch.multiprocessing import spawn -from oml.const import TMP_PATH -from oml.utils.io import calc_hash from oml.utils.misc import set_global_seed @@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore raise ValueError(f"The function '{fn.__name__}' should have 'rank' and 'world_size' as the first two arguments") -def generate_connection_filename(world_size: int, fn: Callable, *args: Tuple[Any, ...]) -> Path: # type: ignore - python_info = sys.version_info - python_str = f"{python_info.major}.{python_info.minor}.{python_info.micro}" - args_hash = calc_hash(pickle.dumps(args)) - filename = TMP_PATH / f"{fn.__name__}_{world_size}_{python_str}_{args_hash}" - return filename +def get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] def fn_ddp_wrapper( - rank: int, connection_file: Path, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore + rank: int, port: int, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore ) -> Any: # type: ignore init_process_group( backend="gloo", rank=rank, world_size=world_size, - init_method=f"file://{connection_file}", + init_method=f"tcp://127.0.0.1:{port}", timeout=timedelta(seconds=120), ) set_global_seed(1) @@ -50,12 +44,12 @@ def fn_ddp_wrapper( def run_in_ddp(world_size: int, fn: Callable, args: Tuple[Any, ...] = ()) -> Any: # type: ignore assert_signature(fn) + set_global_seed(1) + torch.set_num_threads(1) if world_size > 1: - connection_file = generate_connection_filename(world_size, fn, *args) - connection_file.unlink(missing_ok=True) - connection_file.parent.mkdir(exist_ok=True, parents=True) + port = get_free_port() # note, 'spawn' automatically passes 'rank' to its first argument - spawn(fn_ddp_wrapper, args=(connection_file, world_size, fn, *args), nprocs=world_size, join=True) + spawn(fn_ddp_wrapper, args=(port, world_size, fn, *args), nprocs=world_size, join=True) else: set_global_seed(1) return fn(0, world_size, *args)