From 866873a36f5a38c889d4c8c3c8263cc6b97814ff Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 28 Jan 2025 15:45:42 -0800 Subject: [PATCH] process_group/ManagedProcessGroup: ensure quorum and PG is configured before operations (#83) --- torchft/process_group.py | 5 ++++- torchft/process_group_test.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index e689288..d1d2cbe 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -19,7 +19,6 @@ import logging import queue import threading -from abc import ABC from datetime import timedelta from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union @@ -507,6 +506,10 @@ def __init__(self, manager: "Manager") -> None: self._manager = manager def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: + # Ensure we have a valid quorum and are configured before trying to do + # any work. + self._manager.wait_quorum() + if self._manager.errored() is not None: return _DummyWork(tensors) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 417c32e..d24f838 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -368,6 +368,7 @@ def test_managed_process_group(self) -> None: self.assertEqual(manager.report_error.call_count, 0) self.assertEqual(manager.wrap_future.call_count, 1) + self.assertEqual(manager.wait_quorum.call_count, 1) class DeviceMeshTest(TestCase):