coordination: expose new low level torchft coordination API #84

Draft · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions docs/source/coordination.rst
@@ -0,0 +1,4 @@
.. automodule:: torchft.coordination
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ the entire training job.
data
checkpointing
parameter_server
coordination


License
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dev = [

[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "torchft._torchft"

[project.scripts]
torchft_lighthouse = "torchft.torchft:lighthouse_main"
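The new `module-name = "torchft._torchft"` setting makes maturin emit the compiled extension as the private `torchft._torchft` module instead of `torchft.torchft`, which is what the import changes in the rest of this diff track. A quick sanity check, a sketch assuming the extension has been built locally (e.g. with `maturin develop`):

```python
# Editorial sketch, not part of the diff: confirm the renamed extension module imports.
from torchft._torchft import LighthouseServer, ManagerClient, ManagerServer

print(LighthouseServer.__doc__)
```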
2 changes: 1 addition & 1 deletion src/bin/lighthouse.rs
@@ -4,8 +4,8 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

use _torchft::lighthouse::{Lighthouse, LighthouseOpt};
use structopt::StructOpt;
use torchft::lighthouse::{Lighthouse, LighthouseOpt};

#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
async fn main() {
77 changes: 68 additions & 9 deletions src/lib.rs
@@ -30,15 +30,30 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
use pyo3::prelude::*;

/// ManagerServer is a GRPC server for the manager service.
/// There should be one manager server per replica group (typically running on
/// the rank 0 host). The individual ranks within a replica group should use
/// ManagerClient to communicate with the manager server and participate in
/// quorum operations.
///
/// Args:
/// replica_id (str): The ID of the replica group.
/// lighthouse_addr (str): The HTTP address of the lighthouse server.
/// hostname (str): The hostname of the manager server.
/// bind (str): The HTTP address to bind the server to.
/// store_addr (str): The HTTP address of the store server.
/// world_size (int): The world size of the replica group.
/// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
#[pyclass]
struct Manager {
struct ManagerServer {
handle: JoinHandle<Result<()>>,
manager: Arc<manager::Manager>,
_runtime: Runtime,
}

#[pymethods]
impl Manager {
impl ManagerServer {
#[new]
fn new(
py: Python<'_>,
@@ -74,17 +89,29 @@ impl Manager {
})
}

/// address returns the address of the manager server.
///
/// Returns:
/// str: The address of the manager server.
fn address(&self) -> PyResult<String> {
Ok(self.manager.address().to_string())
}

/// shutdown shuts down the manager server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
})
}
}

/// ManagerClient is a GRPC client to the manager service.
///
/// It is used by the trainer to communicate with the ManagerServer.
///
/// Args:
/// addr (str): The HTTP address of the manager server.
/// connect_timeout (timedelta): The timeout for connecting to the manager server.
#[pyclass]
struct ManagerClient {
runtime: Runtime,
@@ -108,7 +135,7 @@ impl ManagerClient {
})
}

fn quorum(
fn _quorum(
&self,
py: Python<'_>,
rank: i64,
@@ -147,7 +174,7 @@ impl ManagerClient {
})
}

fn checkpoint_metadata(
fn _checkpoint_metadata(
&self,
py: Python<'_>,
rank: i64,
@@ -168,6 +195,20 @@
})
}

/// should_commit makes a request to the manager to determine if the trainer
/// should commit the current step. This waits until all ranks check in at
/// the specified step and will return false if any worker passes
/// ``should_commit=False``.
///
/// Args:
/// rank (int): The rank of the trainer.
/// step (int): The step of the trainer.
/// should_commit (bool): Whether the trainer should commit the current step.
/// timeout (timedelta): The timeout for the request. If the request
/// times out, a TimeoutError is raised.
///
/// Returns:
/// bool: Whether the trainer should commit the current step.
fn should_commit(
&self,
py: Python<'_>,
@@ -263,15 +304,28 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
Ok(())
}

/// LighthouseServer is a GRPC server for the lighthouse service.
///
/// It is used to coordinate the ManagerServer for each replica group.
///
/// This entrypoint is primarily for testing and debugging purposes. The
/// ``torchft_lighthouse`` command is recommended for most use cases.
///
/// Args:
/// bind (str): The HTTP address to bind the server to.
/// min_replicas (int): The minimum number of replicas required to form a quorum.
/// join_timeout_ms (int): The timeout for joining the quorum.
/// quorum_tick_ms (int): The interval at which the quorum is checked.
/// heartbeat_timeout_ms (int): The timeout for heartbeats.
#[pyclass]
struct Lighthouse {
struct LighthouseServer {
lighthouse: Arc<lighthouse::Lighthouse>,
handle: JoinHandle<Result<()>>,
_runtime: Runtime,
}

#[pymethods]
impl Lighthouse {
impl LighthouseServer {
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
#[new]
fn new(
@@ -307,10 +361,15 @@ impl Lighthouse {
})
}

/// address returns the address of the lighthouse server.
///
/// Returns:
/// str: The address of the lighthouse server.
fn address(&self) -> PyResult<String> {
Ok(self.lighthouse.address().to_string())
}

/// shutdown shuts down the lighthouse server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
@@ -339,7 +398,7 @@ impl From<Status> for StatusError {
}

#[pymodule]
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
let mut log = stderrlog::new();
log.verbosity(2)
@@ -353,9 +412,9 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
log.init()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<Manager>()?;
m.add_class::<ManagerServer>()?;
m.add_class::<ManagerClient>()?;
m.add_class::<Lighthouse>()?;
m.add_class::<LighthouseServer>()?;
m.add_class::<QuorumResult>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

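The docstrings added above describe the trainer-facing commit protocol. The sketch below is editorial and not part of the diff: it shows how a rank might call `should_commit` through `ManagerClient`, assuming a `ManagerServer` is already listening at the hypothetical address and using placeholder rank, step, and timeout values; the argument names follow the docstring above.

```python
from datetime import timedelta

from torchft._torchft import ManagerClient

# Hypothetical address of an already-running ManagerServer for this replica group.
manager_addr = "http://localhost:29512"

client = ManagerClient(manager_addr, connect_timeout=timedelta(seconds=10))

# Each rank reports whether its step succeeded. The call blocks until every
# rank has checked in for this step and returns False if any rank passed
# should_commit=False, so all ranks either commit or roll back together.
ok = client.should_commit(
    rank=0,
    step=5,
    should_commit=True,
    timeout=timedelta(seconds=10),
)
print("commit" if ok else "roll back")
```

The high-level `Manager` in `torchft/manager.py` drives the same client underneath, as the changes further down show.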
8 changes: 4 additions & 4 deletions torchft/torchft.pyi → torchft/_torchft.pyi
@@ -3,15 +3,15 @@ from typing import List, Optional

class ManagerClient:
def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
def quorum(
def _quorum(
self,
rank: int,
step: int,
checkpoint_metadata: str,
shrink_only: bool,
timeout: timedelta,
) -> QuorumResult: ...
def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit(
self,
rank: int,
@@ -33,7 +33,7 @@
max_world_size: int
heal: bool

class Manager:
class ManagerServer:
def __init__(
self,
replica_id: str,
@@ -48,7 +48,7 @@
def address(self) -> str: ...
def shutdown(self) -> None: ...

class Lighthouse:
class LighthouseServer:
def __init__(
self,
bind: str,
24 changes: 24 additions & 0 deletions torchft/coordination.py
@@ -0,0 +1,24 @@
"""
Coordination (Low Level API)
============================

.. warning::
As torchft is still in development, the APIs in this module are subject to change.

This module exposes low level coordination APIs to allow you to build your own
custom fault tolerance algorithms on top of torchft.

If you're looking for a more complete solution, please use the other modules in
torchft.

This provides direct access to the Lighthouse and Manager servers and clients.
"""

from torchft._torchft import LighthouseServer, ManagerClient, ManagerServer


__all__ = [
"LighthouseServer",
"ManagerServer",
"ManagerClient",
]
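To make the new surface concrete, here is a minimal single-host sketch, editorial and not part of the diff, that stands up the servers and a client using the constructor arguments documented above. The replica id, port, and timeout values are illustrative, and a `TCPStore` is assumed to be reachable at `store_addr`.

```python
from datetime import timedelta

from torch.distributed import TCPStore

from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer

# The lighthouse coordinates the ManagerServer of each replica group.
lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)

# Store backing the replica group; 29400 is an arbitrary port assumed to be free.
store = TCPStore("localhost", 29400, is_master=True, wait_for_workers=False)

manager = ManagerServer(
    replica_id="replica_0",
    lighthouse_addr=lighthouse.address(),
    hostname="localhost",
    bind="[::]:0",
    store_addr="localhost:29400",
    world_size=1,
    heartbeat_interval=timedelta(milliseconds=100),
    connect_timeout=timedelta(seconds=10),
)

# Trainer ranks talk to their replica group's manager through ManagerClient.
client = ManagerClient(manager.address(), connect_timeout=timedelta(seconds=10))

manager.shutdown()
lighthouse.shutdown()
```

In the managed path, `torchft.manager.Manager` performs this wiring internally, as the `torchft/manager.py` changes below show.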
19 changes: 19 additions & 0 deletions torchft/coordination_test.py
@@ -0,0 +1,19 @@
import inspect
from unittest import TestCase

from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer


class TestCoordination(TestCase):
def test_coordination_docs(self) -> None:
classes = [
ManagerClient,
ManagerServer,
LighthouseServer,
]
for cls in classes:
self.assertIn("Args:", str(cls.__doc__), cls)
for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
if name.startswith("_"):
continue
self.assertIn("Args:", str(method.__doc__), method)
8 changes: 4 additions & 4 deletions torchft/lighthouse_test.py
@@ -4,15 +4,15 @@
import torch.distributed as dist

from torchft import Manager, ProcessGroupGloo
from torchft.torchft import Lighthouse
from torchft._torchft import LighthouseServer


class TestLighthouse(TestCase):
def test_join_timeout_behavior(self) -> None:
"""Test that join_timeout_ms affects joining behavior"""
# To test, we create a lighthouse with 100ms and 400ms join timeouts
# and measure the time taken to validate the quorum.
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
join_timeout_ms=100,
@@ -52,14 +52,14 @@ def test_join_timeout_behavior(self) -> None:
if "manager" in locals():
manager.shutdown()

lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
join_timeout_ms=400,
)

def test_heartbeat_timeout_ms_sanity(self) -> None:
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
heartbeat_timeout_ms=100,
10 changes: 5 additions & 5 deletions torchft/manager.py
@@ -38,9 +38,9 @@
import torch
from torch.distributed import ReduceOp, TCPStore

from torchft._torchft import ManagerClient, ManagerServer
from torchft.checkpointing import CheckpointServer, CheckpointTransport
from torchft.futures import future_timeout
from torchft.torchft import Manager as _Manager, ManagerClient

if TYPE_CHECKING:
from torchft.process_group import ProcessGroup
@@ -181,7 +181,7 @@ def __init__(
wait_for_workers=False,
)
self._pg = pg
self._manager: Optional[_Manager] = None
self._manager: Optional[ManagerServer] = None

if rank == 0:
if port is None:
@@ -193,7 +193,7 @@
if replica_id is None:
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
self._manager = _Manager(
self._manager = ManagerServer(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
hostname=hostname,
@@ -424,7 +424,7 @@ def wait_quorum(self) -> None:
def _async_quorum(
self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
) -> None:
quorum = self._client.quorum(
quorum = self._client._quorum(
rank=self._rank,
step=self._step,
checkpoint_metadata=self._checkpoint_transport.metadata(),
@@ -493,7 +493,7 @@ def _async_quorum(
primary_client = ManagerClient(
recover_src_manager_address, connect_timeout=self._connect_timeout
)
checkpoint_metadata = primary_client.checkpoint_metadata(
checkpoint_metadata = primary_client._checkpoint_metadata(
self._rank, timeout=self._timeout
)
recover_src_rank = quorum.recover_src_rank