
Add max retries to quorum to avoid live locks #166


Merged
merged 8 commits on Apr 18, 2025
21 changes: 21 additions & 0 deletions torchft/manager.py
@@ -107,6 +107,7 @@ def __init__(
heartbeat_interval: timedelta = timedelta(milliseconds=100),
checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
init_sync: bool = True,
max_retries: Optional[int] = None,
) -> None:
"""
Args:
@@ -147,6 +148,8 @@ def __init__(
init_sync: whether to synchronize the model weights on step 0. If
all of the model weights are initialized identically via
``torch.set_seed`` you should set this to False.
max_retries: the maximum number of consecutive should_commit failures to allow
before raising an exception. If None, will retry indefinitely.
"""
self._load_state_dict = load_state_dict
self._user_state_dict = state_dict
@@ -157,6 +160,8 @@
self._connect_timeout = connect_timeout
self._world_size_mode = world_size_mode
self._init_sync = init_sync
self._max_retries = max_retries
self._commit_failures = 0

store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -595,8 +600,13 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:

This should only be called once per step.

If max_retries is set and should_commit fails more than max_retries times
consecutively, this method raises a RuntimeError to prevent indefinite failure loops.

Returns:
True if the optimizer should be stepped, False otherwise
Raises:
RuntimeError: if max_retries is set and should_commit fails more than max_retries times in a row
"""
for work in self._pending_work:
# check at the beginning since .wait() may trigger errors
@@ -638,6 +648,17 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
if should_commit:
self._step += 1
self._batches_committed += self.num_participants()
self._commit_failures = 0 # Reset failure counter on success
else:
self._commit_failures += 1
# Check if we've hit max retries
if (
self._max_retries is not None
and self._commit_failures > self._max_retries
):
msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}"
self._logger.exception(msg)
raise RuntimeError(msg)

return should_commit

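For context, here is a minimal sketch (not part of this diff) of how a training loop might consume the new behavior. The names train_loop, dataloader, model, optimizer, and batch are hypothetical illustration plumbing; only max_retries, start_quorum(), should_commit(), and shutdown() come from the torchft Manager API shown above.

from torchft import Manager

def train_loop(manager: Manager, dataloader, model, optimizer) -> None:
    # manager is assumed to have been constructed with max_retries set,
    # e.g. Manager(..., init_sync=True, max_retries=3); construction elided.
    for batch in dataloader:
        optimizer.zero_grad()
        manager.start_quorum()            # join/form the quorum for this step
        loss = model(batch).sum()         # placeholder forward/backward
        loss.backward()
        try:
            if manager.should_commit():   # False if a peer failed mid-step
                optimizer.step()
        except RuntimeError:
            # Raised once should_commit() has failed more than max_retries
            # consecutive times; fail fast instead of live-locking.
            manager.shutdown(wait=False)
            raise

Because the counter is reset to 0 on every successful commit, only consecutive failures count toward max_retries; intermittent failures never trip the limit.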
55 changes: 52 additions & 3 deletions torchft/manager_test.py
@@ -30,9 +30,9 @@ class TestManager(TestCase):
manager: Optional[Manager] # pyre-fixme[13]: never initialized

def tearDown(self) -> None:
manager = self.manager
if manager is not None:
manager.shutdown(wait=False)
# Manager cleanup might be handled by _create_manager
if hasattr(self, "manager") and self.manager is not None:
self.manager.shutdown(wait=False)

def _create_manager(
self,
@@ -41,6 +41,7 @@ def _create_manager(
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
timeout: timedelta = timedelta(seconds=10),
init_sync: bool = True,
max_retries: Optional[int] = None,
) -> Manager:
pg = create_autospec(ProcessGroup)
pg.errored.return_value = None
@@ -69,6 +70,7 @@ def _create_manager(
world_size_mode=world_size_mode,
timeout=timeout,
init_sync=init_sync,
max_retries=max_retries,
)
self.manager = manager
return manager
@@ -645,3 +647,50 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
manager._init_sync = True
manager.start_quorum()
self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_max_retries(self, client_mock: MagicMock) -> None:
# Create a manager with max_retries=2
manager = self._create_manager(max_retries=2)

# Setup quorum for testing
quorum = QuorumResult()
quorum.quorum_id = 123
quorum.replica_rank = 1
quorum.replica_world_size = 2
quorum.recover_src_manager_address = "manager address"
quorum.store_address = f"localhost:{self.store.port}"
quorum.max_step = 1
quorum.max_rank = 1
quorum.max_world_size = 2
quorum.heal = False
client_mock()._quorum.return_value = quorum

# Make should_commit always return False to simulate failures
client_mock().should_commit = MagicMock(return_value=False)

# Start quorum
manager.start_quorum()

# First failure
self.assertFalse(manager.should_commit())
self.assertEqual(manager._commit_failures, 1)

# Second failure
self.assertFalse(manager.should_commit())
self.assertEqual(manager._commit_failures, 2)

# Third failure - should raise exception
with self.assertRaises(RuntimeError) as context:
manager.should_commit()

self.assertIn("exceeding max_retries=2", str(context.exception))
self.assertEqual(manager._commit_failures, 3)

# Now test that success resets the counter
manager._commit_failures = 2 # Reset to just before failure threshold
client_mock().should_commit = MagicMock(return_value=True) # Now succeed

# This should succeed and reset the counter
self.assertTrue(manager.should_commit())
self.assertEqual(manager._commit_failures, 0)