diff --git a/proto/torchft.proto b/proto/torchft.proto
index 7e248e5..5213c17 100644
--- a/proto/torchft.proto
+++ b/proto/torchft.proto
@@ -74,6 +74,7 @@ message ManagerQuorumRequest {
   int64 step = 2;
   string checkpoint_metadata = 3;
   bool shrink_only = 4;
+  bool init_sync = 5;
 }
 
 message ManagerQuorumResponse {
diff --git a/src/lib.rs b/src/lib.rs
index 2d4de57..cf98c99 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -154,6 +154,7 @@ impl ManagerClient {
                 step: step,
                 checkpoint_metadata: checkpoint_metadata,
                 shrink_only: shrink_only,
+                init_sync: true,
             });
 
             // This timeout is processed on the server side so we also enable
diff --git a/src/manager.rs b/src/manager.rs
index bd14783..d94c9e4 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -285,7 +285,7 @@ impl ManagerService for Arc<Manager> {
 
         info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);
 
-        let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
+        let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?;
 
         Ok(Response::new(reply))
     }
@@ -381,6 +381,7 @@ fn compute_quorum_results(
     replica_id: &str,
     rank: i64,
     quorum: &Quorum,
+    init_sync: bool,
 ) -> Result<ManagerQuorumResponse, Status> {
     let mut participants = quorum.participants.clone();
     participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
@@ -423,20 +424,25 @@ fn compute_quorum_results(
 
     // Compute recovery assignments
 
-    // Nodes are recovering if:
-    // 1. not at the max step
-    // 2. max_step == 0 and not the primary replica
-    let all_recover_dst_ranks: Vec<usize> = participants
-        .iter()
-        .enumerate()
-        .filter_map(|(i, p)| {
-            if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
-                Some(i)
-            } else {
-                None
-            }
-        })
-        .collect();
+    let all_recover_dst_ranks = if init_sync {
+        // Nodes are recovering if
+        // 1. not at the max step
+        // 2. max_step == 0 and not the primary replica
+        participants
+            .iter()
+            .enumerate()
+            .filter_map(|(i, p)| {
+                if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect()
+    } else {
+        Vec::<usize>::new()
+    };
+
     let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
     let up_to_date_ranks: Vec<usize> = participants
         .iter()
@@ -604,6 +610,7 @@ mod tests {
            step: 123,
            checkpoint_metadata: "addr".to_string(),
            shrink_only: false,
+           init_sync: true,
        });
        request.set_timeout(Duration::from_secs(10));
        let resp = client.quorum(request).await?.into_inner();
@@ -663,6 +670,7 @@ mod tests {
                step: 0,
                checkpoint_metadata: "addr".to_string(),
                shrink_only: false,
+               init_sync: true,
            });
            request.set_timeout(Duration::from_secs(10));
 
@@ -768,13 +776,13 @@ mod tests {
 
        // rank 0
 
-       let results = compute_quorum_results("replica_0", 0, &quorum)?;
+       let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
        assert!(!results.heal);
        assert_eq!(results.replica_rank, 0);
        assert_eq!(results.recover_src_rank, None);
        assert_eq!(results.recover_dst_ranks, vec![1]);
 
-       let results = compute_quorum_results("replica_1", 0, &quorum)?;
+       let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
        assert!(results.heal);
        assert_eq!(results.replica_rank, 1);
        assert_eq!(results.recover_src_rank, Some(0));
@@ -782,7 +790,7 @@ mod tests {
 
        // rank 1 assignments should be offset from rank 0 above and the primary
 
-       let results = compute_quorum_results("replica_1", 1, &quorum)?;
+       let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
        assert!(!results.heal);
        assert_eq!(results.replica_rank, 1);
        assert_eq!(results.recover_src_rank, None);
@@ -842,21 +850,21 @@ mod tests {
 
        // rank 0
 
-       let results = compute_quorum_results("replica_0", 0, &quorum)?;
+       let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
        assert!(results.heal);
        assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
        assert_eq!(results.replica_rank, 0);
        assert_eq!(results.recover_src_rank, Some(1));
        assert!(results.recover_dst_ranks.is_empty());
 
-       let results = compute_quorum_results("replica_1", 0, &quorum)?;
+       let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
        assert!(!results.heal);
        assert_eq!(results.recover_src_manager_address, "".to_string());
        assert_eq!(results.replica_rank, 1);
        assert_eq!(results.recover_src_rank, None);
        assert_eq!(results.recover_dst_ranks, vec![0, 4]);
 
-       let results = compute_quorum_results("replica_3", 0, &quorum)?;
+       let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
        assert!(!results.heal);
        assert_eq!(results.replica_rank, 3);
        assert_eq!(results.recover_src_rank, None);
@@ -864,7 +872,7 @@ mod tests {
 
        // rank 1 assignments should be offset from rank 0 above
 
-       let results = compute_quorum_results("replica_1", 1, &quorum)?;
+       let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
        assert!(!results.heal);
        assert_eq!(results.replica_rank, 1);
        assert_eq!(results.recover_src_rank, None);
@@ -872,4 +880,86 @@ mod tests {
 
        Ok(())
    }
+
+   #[tokio::test]
+   async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
+       let quorum = Quorum {
+           quorum_id: 1,
+           participants: vec![
+               QuorumMember {
+                   replica_id: "replica_0".to_string(),
+                   address: "addr_0".to_string(),
+                   store_address: "store_addr_0".to_string(),
+                   step: 0,
+                   world_size: 1,
+                   shrink_only: false,
+               },
+               QuorumMember {
+                   replica_id: "replica_1".to_string(),
+                   address: "addr_1".to_string(),
+                   store_address: "store_addr_1".to_string(),
+                   step: 1,
+                   world_size: 1,
+                   shrink_only: false,
+               },
+               QuorumMember {
+                   replica_id: "replica_2".to_string(),
+                   address: "addr_2".to_string(),
+                   store_address: "store_addr_2".to_string(),
+                   step: 0,
+                   world_size: 1,
+                   shrink_only: false,
+               },
+               QuorumMember {
+                   replica_id: "replica_3".to_string(),
+                   address: "addr_3".to_string(),
+                   store_address: "store_addr_3".to_string(),
+                   step: 1,
+                   world_size: 1,
+                   shrink_only: false,
+               },
+               QuorumMember {
+                   replica_id: "replica_4".to_string(),
+                   address: "addr_4".to_string(),
+                   store_address: "store_addr_4".to_string(),
+                   step: 0,
+                   world_size: 1,
+                   shrink_only: false,
+               },
+           ],
+           created: None,
+       };
+
+       // rank 0
+
+       let results = compute_quorum_results("replica_0", 0, &quorum, false)?;
+       assert!(!results.heal);
+       assert_eq!(results.recover_src_manager_address, "".to_string());
+       assert_eq!(results.replica_rank, 0);
+       assert_eq!(results.recover_src_rank, None);
+       assert!(results.recover_dst_ranks.is_empty());
+
+       let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+       assert!(!results.heal);
+       assert_eq!(results.recover_src_manager_address, "".to_string());
+       assert_eq!(results.replica_rank, 1);
+       assert_eq!(results.recover_src_rank, None);
+       assert!(results.recover_dst_ranks.is_empty());
+
+       let results = compute_quorum_results("replica_3", 0, &quorum, false)?;
+       assert!(!results.heal);
+       assert_eq!(results.recover_src_manager_address, "".to_string());
+       assert_eq!(results.replica_rank, 3);
+       assert_eq!(results.recover_src_rank, None);
+       assert!(results.recover_dst_ranks.is_empty());
+
+       let results = compute_quorum_results("replica_1", 1, &quorum, false)?;
+       assert!(!results.heal);
+       assert_eq!(results.recover_src_manager_address, "".to_string());
+       assert_eq!(results.replica_rank, 1);
+       assert_eq!(results.recover_src_rank, None);
+       assert!(results.recover_dst_ranks.is_empty());
+
+       Ok(())
+   }
 }
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index b4afde6..9f182a3 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -10,6 +10,7 @@ class ManagerClient:
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
+        init_sync: bool = True,
     ) -> QuorumResult: ...
     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
diff --git a/torchft/manager.py b/torchft/manager.py
index 0da48d0..50587ee 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -372,6 +372,7 @@ def start_quorum(
         allow_heal: bool = True,
         shrink_only: bool = False,
         timeout: Optional[timedelta] = None,
+        init_sync: bool = True,
     ) -> None:
         """
         .. note::
@@ -407,6 +408,7 @@ def start_quorum(
             allow_heal=allow_heal,
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
+            init_sync=init_sync,
         )
         if not self._use_async_quorum:
             self.wait_quorum()
@@ -431,7 +433,7 @@ def wait_quorum(self) -> None:
         self._quorum_future.result()
 
     def _async_quorum(
-        self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
+        self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta, init_sync: bool
     ) -> None:
         quorum = self._client._quorum(
             rank=self._rank,
@@ -439,6 +441,7 @@ def _async_quorum(
             checkpoint_metadata=self._checkpoint_transport.metadata(),
             shrink_only=shrink_only,
             timeout=quorum_timeout,
+            init_sync=init_sync,
         )
 
         quorum_id = quorum.quorum_id
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 05793e1..f5fd865 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -579,3 +579,35 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
             client_mock().should_commit.call_args.kwargs["timeout"],
             timedelta(seconds=23),
         )
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=False)
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], True
+        )
+
+        manager.start_quorum(init_sync=True)
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], True
+        )
+
+        manager.start_quorum(init_sync=False)
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], False
+        )
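
Usage sketch (not part of the patch): with init_sync left at its default of True, a replica at step 0 that is not the primary is assigned a recovery rank and heals from the primary before training; passing init_sync=False makes compute_quorum_results return no recovery assignments, so each replica keeps its locally initialized weights. This can be appropriate when every replica already starts from identical weights (e.g. a shared seed or checkpoint). The Manager construction below is elided and hypothetical; only the start_quorum calls are exercised by this diff.

    from torchft.manager import Manager

    # manager = Manager(...)  # constructed as usual; arguments elided

    # Default: replicas behind max_step (or at step 0 and not the
    # primary) heal from a recovery source before training.
    manager.start_quorum()

    # Skip the initial synchronization: no heal, no recovery
    # src/dst rank assignments at step 0.
    manager.start_quorum(init_sync=False)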