Skip to content

Add option to skip init sync #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions proto/torchft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ message ManagerQuorumRequest {
int64 step = 2;
string checkpoint_metadata = 3;
bool shrink_only = 4;
bool init_sync = 5;
}

message ManagerQuorumResponse {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl ManagerClient {
step: step,
checkpoint_metadata: checkpoint_metadata,
shrink_only: shrink_only,
init_sync: true,
});

// This timeout is processed on the server side so we also enable
Expand Down
134 changes: 112 additions & 22 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ impl ManagerService for Arc<Manager> {

info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);

let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?;

Ok(Response::new(reply))
}
Expand Down Expand Up @@ -381,6 +381,7 @@ fn compute_quorum_results(
replica_id: &str,
rank: i64,
quorum: &Quorum,
init_sync: bool,
) -> Result<ManagerQuorumResponse, Status> {
let mut participants = quorum.participants.clone();
participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
Expand Down Expand Up @@ -423,20 +424,25 @@ fn compute_quorum_results(

// Compute recovery assignments

// Nodes are recovering if:
// 1. not at the max step
// 2. max_step == 0 and not the primary replica
let all_recover_dst_ranks: Vec<usize> = participants
.iter()
.enumerate()
.filter_map(|(i, p)| {
if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
Some(i)
} else {
None
}
})
.collect();
let all_recover_dst_ranks = if init_sync {
// Nodes are recovering if
// 1. not at the max step
// 2. max_step == 0 and not the primary replica
participants
.iter()
.enumerate()
.filter_map(|(i, p)| {
if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
Some(i)
} else {
None
}
})
.collect()
} else {
Vec::<usize>::new()
};

let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
let up_to_date_ranks: Vec<usize> = participants
.iter()
Expand Down Expand Up @@ -604,6 +610,7 @@ mod tests {
step: 123,
checkpoint_metadata: "addr".to_string(),
shrink_only: false,
init_sync: true,
});
request.set_timeout(Duration::from_secs(10));
let resp = client.quorum(request).await?.into_inner();
Expand Down Expand Up @@ -663,6 +670,7 @@ mod tests {
step: 0,
checkpoint_metadata: "addr".to_string(),
shrink_only: false,
init_sync: true,
});
request.set_timeout(Duration::from_secs(10));

Expand Down Expand Up @@ -768,21 +776,21 @@ mod tests {

// rank 0

let results = compute_quorum_results("replica_0", 0, &quorum)?;
let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 0);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![1]);

let results = compute_quorum_results("replica_1", 0, &quorum)?;
let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
assert!(results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, Some(0));
assert_eq!(results.recover_dst_ranks, Vec::<i64>::new());

// rank 1 assignments should be offset from rank 0 above and the primary

let results = compute_quorum_results("replica_1", 1, &quorum)?;
let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
Expand Down Expand Up @@ -842,34 +850,116 @@ mod tests {

// rank 0

let results = compute_quorum_results("replica_0", 0, &quorum)?;
let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
assert!(results.heal);
assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
assert_eq!(results.replica_rank, 0);
assert_eq!(results.recover_src_rank, Some(1));
assert!(results.recover_dst_ranks.is_empty());

let results = compute_quorum_results("replica_1", 0, &quorum)?;
let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.recover_src_manager_address, "".to_string());
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![0, 4]);

let results = compute_quorum_results("replica_3", 0, &quorum)?;
let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 3);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![2]);

// rank 1 assignments should be offset from rank 0 above

let results = compute_quorum_results("replica_1", 1, &quorum)?;
let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![2]);

Ok(())
}

#[tokio::test]
async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
    // Helper: build the i-th quorum member at the given step; all other
    // fields follow the naming convention used throughout these tests.
    let member = |i: usize, step: i64| QuorumMember {
        replica_id: format!("replica_{}", i),
        address: format!("addr_{}", i),
        store_address: format!("store_addr_{}", i),
        step,
        world_size: 1,
        shrink_only: false,
    };

    // Five replicas with mixed steps (0 and 1) so that, with init_sync
    // enabled, some of them would be selected for recovery.
    let quorum = Quorum {
        quorum_id: 1,
        participants: vec![
            member(0, 0),
            member(1, 1),
            member(2, 0),
            member(3, 1),
            member(4, 0),
        ],
        created: None,
    };

    // With init_sync disabled, no replica should heal or be assigned as a
    // recovery source/destination, regardless of its step or rank.
    for (replica_id, rank, expected_replica_rank) in [
        ("replica_0", 0, 0),
        ("replica_1", 0, 1),
        ("replica_3", 0, 3),
        // rank 1 behaves the same as rank 0
        ("replica_1", 1, 1),
    ] {
        let results = compute_quorum_results(replica_id, rank, &quorum, false)?;
        assert!(!results.heal);
        assert_eq!(results.recover_src_manager_address, "".to_string());
        assert_eq!(results.replica_rank, expected_replica_rank);
        assert_eq!(results.recover_src_rank, None);
        assert!(results.recover_dst_ranks.is_empty());
    }

    Ok(())
}
}
1 change: 1 addition & 0 deletions torchft/_torchft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ManagerClient:
checkpoint_metadata: str,
shrink_only: bool,
timeout: timedelta,
init_sync: bool = True,
) -> QuorumResult: ...
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit(
Expand Down
5 changes: 4 additions & 1 deletion torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def start_quorum(
allow_heal: bool = True,
shrink_only: bool = False,
timeout: Optional[timedelta] = None,
init_sync: bool = True,
) -> None:
"""
.. note::
Expand Down Expand Up @@ -407,6 +408,7 @@ def start_quorum(
allow_heal=allow_heal,
shrink_only=shrink_only,
quorum_timeout=timeout or self._quorum_timeout,
init_sync=init_sync,
)
if not self._use_async_quorum:
self.wait_quorum()
Expand All @@ -431,14 +433,15 @@ def wait_quorum(self) -> None:
self._quorum_future.result()

def _async_quorum(
self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta, init_sync: bool
) -> None:
quorum = self._client._quorum(
rank=self._rank,
step=self._step,
checkpoint_metadata=self._checkpoint_transport.metadata(),
shrink_only=shrink_only,
timeout=quorum_timeout,
init_sync=init_sync,
)

quorum_id = quorum.quorum_id
Expand Down
32 changes: 32 additions & 0 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,35 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
client_mock().should_commit.call_args.kwargs["timeout"],
timedelta(seconds=23),
)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
    """Verify start_quorum forwards init_sync to the client (default True)."""
    manager = self._create_manager(use_async_quorum=False)

    # Minimal healthy quorum result so start_quorum completes without healing.
    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    # (kwargs passed to start_quorum, init_sync value expected at the client)
    cases = (
        ({}, True),  # omitted -> defaults to True
        ({"init_sync": True}, True),
        ({"init_sync": False}, False),
    )
    for kwargs, expected in cases:
        manager.start_quorum(**kwargs)
        self.assertEqual(
            client_mock()._quorum.call_args.kwargs["init_sync"], expected
        )