Skip to content

Commit

Permalink
* reworked ddp func to avoid dead locks in tests
Browse files Browse the repository at this point in the history
* reworked ddp func to avoid dead locks in tests
  • Loading branch information
DaloroAT authored Feb 2, 2025
1 parent 8300607 commit 317e376
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
2 changes: 0 additions & 2 deletions tests/test_oml/test_ddp/test_loader_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .utils import run_in_ddp


@pytest.mark.skip(reason="Dead locks may appear when running in CI")
@pytest.mark.long
@pytest.mark.parametrize("n_labels_sampler", [2, 5])
@pytest.mark.parametrize("n_instances_sampler", [2, 5])
Expand Down Expand Up @@ -156,7 +155,6 @@ def check_patching_balance_batch_sampler(
assert len(set(outputs_from_epochs)) == len(outputs_from_epochs)


@pytest.mark.skip(reason="Dead locks may appear when running in CI")
@pytest.mark.long
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
Expand Down
28 changes: 11 additions & 17 deletions tests/test_oml/test_ddp/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import inspect
import pickle
import sys
import socket
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Tuple

import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.multiprocessing import spawn

from oml.const import TMP_PATH
from oml.utils.io import calc_hash
from oml.utils.misc import set_global_seed


Expand All @@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore
raise ValueError(f"The function '{fn.__name__}' should have 'rank' and 'world_size' as the first two arguments")


def generate_connection_filename(world_size: int, fn: Callable, *args: Tuple[Any, ...]) -> Path: # type: ignore
python_info = sys.version_info
python_str = f"{python_info.major}.{python_info.minor}.{python_info.micro}"
args_hash = calc_hash(pickle.dumps(args))
filename = TMP_PATH / f"{fn.__name__}_{world_size}_{python_str}_{args_hash}"
return filename
def get_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


def fn_ddp_wrapper(
rank: int, connection_file: Path, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
rank: int, port: int, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
) -> Any: # type: ignore
init_process_group(
backend="gloo",
rank=rank,
world_size=world_size,
init_method=f"file://{connection_file}",
init_method=f"tcp://127.0.0.1:{port}",
timeout=timedelta(seconds=120),
)
set_global_seed(1)
Expand All @@ -50,12 +44,12 @@ def fn_ddp_wrapper(

def run_in_ddp(world_size: int, fn: Callable, args: Tuple[Any, ...] = ()) -> Any: # type: ignore
assert_signature(fn)
set_global_seed(1)
torch.set_num_threads(1)
if world_size > 1:
connection_file = generate_connection_filename(world_size, fn, *args)
connection_file.unlink(missing_ok=True)
connection_file.parent.mkdir(exist_ok=True, parents=True)
port = get_free_port()
# note, 'spawn' automatically passes 'rank' to its first argument
spawn(fn_ddp_wrapper, args=(connection_file, world_size, fn, *args), nprocs=world_size, join=True)
spawn(fn_ddp_wrapper, args=(port, world_size, fn, *args), nprocs=world_size, join=True)
else:
set_global_seed(1)
return fn(0, world_size, *args)

0 comments on commit 317e376

Please sign in to comment.