Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check tests #633

Merged
merged 9 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/test_oml/test_ddp/test_loader_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .utils import run_in_ddp


@pytest.mark.skip(reason="Dead locks may appear when running in CI")
@pytest.mark.long
@pytest.mark.parametrize("n_labels_sampler", [2, 5])
@pytest.mark.parametrize("n_instances_sampler", [2, 5])
Expand Down Expand Up @@ -156,7 +155,6 @@ def check_patching_balance_batch_sampler(
assert len(set(outputs_from_epochs)) == len(outputs_from_epochs)


@pytest.mark.skip(reason="Dead locks may appear when running in CI")
@pytest.mark.long
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
Expand Down
28 changes: 11 additions & 17 deletions tests/test_oml/test_ddp/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import inspect
import pickle
import sys
import socket
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Tuple

import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.multiprocessing import spawn

from oml.const import TMP_PATH
from oml.utils.io import calc_hash
from oml.utils.misc import set_global_seed


Expand All @@ -23,22 +19,20 @@ def assert_signature(fn: Callable) -> None: # type: ignore
raise ValueError(f"The function '{fn.__name__}' should have 'rank' and 'world_size' as the first two arguments")


def generate_connection_filename(world_size: int, fn: Callable, *args: Tuple[Any, ...]) -> Path: # type: ignore
python_info = sys.version_info
python_str = f"{python_info.major}.{python_info.minor}.{python_info.micro}"
args_hash = calc_hash(pickle.dumps(args))
filename = TMP_PATH / f"{fn.__name__}_{world_size}_{python_str}_{args_hash}"
return filename
def get_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


def fn_ddp_wrapper(
rank: int, connection_file: Path, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
rank: int, port: int, world_size: int, fn: Callable, *args: Tuple[Any, ...] # type: ignore
) -> Any: # type: ignore
init_process_group(
backend="gloo",
rank=rank,
world_size=world_size,
init_method=f"file://{connection_file}",
init_method=f"tcp://127.0.0.1:{port}",
timeout=timedelta(seconds=120),
)
set_global_seed(1)
Expand All @@ -50,12 +44,12 @@ def fn_ddp_wrapper(

def run_in_ddp(world_size: int, fn: Callable, args: Tuple[Any, ...] = ()) -> Any: # type: ignore
assert_signature(fn)
set_global_seed(1)
torch.set_num_threads(1)
if world_size > 1:
connection_file = generate_connection_filename(world_size, fn, *args)
connection_file.unlink(missing_ok=True)
connection_file.parent.mkdir(exist_ok=True, parents=True)
port = get_free_port()
# note, 'spawn' automatically passes 'rank' to its first argument
spawn(fn_ddp_wrapper, args=(connection_file, world_size, fn, *args), nprocs=world_size, join=True)
spawn(fn_ddp_wrapper, args=(port, world_size, fn, *args), nprocs=world_size, join=True)
else:
set_global_seed(1)
return fn(0, world_size, *args)