Skip to content

Commit ccf74d4

Browse files
authored
checkpointing: use CheckpointTransport abstraction (#81)
1 parent e177f9c commit ccf74d4

9 files changed

+748
-294
lines changed

proto/torchft.proto

+15-13
Original file line numberDiff line numberDiff line change
@@ -72,30 +72,32 @@ service LighthouseService {
7272
message ManagerQuorumRequest {
7373
int64 rank = 1;
7474
int64 step = 2;
75-
string checkpoint_server_addr = 3;
75+
string checkpoint_metadata = 3;
7676
bool shrink_only = 4;
7777
}
7878

7979
message ManagerQuorumResponse {
8080
int64 quorum_id = 1;
81-
string address = 2;
82-
string store_address = 3;
81+
string recover_src_manager_address = 2;
82+
optional int64 recover_src_rank = 3;
83+
repeated int64 recover_dst_ranks = 4;
84+
string store_address = 5;
8385
// This is information about the replicas that are at the max step.
84-
int64 max_step = 4;
85-
optional int64 max_rank = 5;
86-
int64 max_world_size = 6;
86+
int64 max_step = 6;
87+
optional int64 max_rank = 7;
88+
int64 max_world_size = 8;
8789
// This is information about all replicas, including replicas that are behind.
88-
int64 replica_rank = 7;
89-
int64 replica_world_size = 8;
90-
bool heal = 9;
90+
int64 replica_rank = 9;
91+
int64 replica_world_size = 10;
92+
bool heal = 11;
9193
}
9294

93-
message CheckpointAddressRequest {
95+
message CheckpointMetadataRequest {
9496
int64 rank = 1;
9597
}
9698

97-
message CheckpointAddressResponse {
98-
string checkpoint_server_address = 1;
99+
message CheckpointMetadataResponse {
100+
string checkpoint_metadata = 1;
99101
}
100102

101103
message ShouldCommitRequest {
@@ -114,7 +116,7 @@ message KillResponse {}
114116

115117
service ManagerService {
116118
rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse);
117-
rpc CheckpointAddress(CheckpointAddressRequest) returns (CheckpointAddressResponse);
119+
rpc CheckpointMetadata(CheckpointMetadataRequest) returns (CheckpointMetadataResponse);
118120
rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
119121
rpc Kill(KillRequest) returns (KillResponse);
120122
}

src/lib.rs

+57-19
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub mod torchftpb {
2727
}
2828

2929
use crate::torchftpb::manager_service_client::ManagerServiceClient;
30-
use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest};
30+
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
3131
use pyo3::prelude::*;
3232

3333
#[pyclass]
@@ -113,15 +113,15 @@ impl ManagerClient {
113113
py: Python<'_>,
114114
rank: i64,
115115
step: i64,
116-
checkpoint_server_addr: String,
116+
checkpoint_metadata: String,
117117
shrink_only: bool,
118118
timeout: Duration,
119-
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
119+
) -> Result<QuorumResult, StatusError> {
120120
py.allow_threads(move || {
121121
let mut request = tonic::Request::new(ManagerQuorumRequest {
122122
rank: rank,
123123
step: step,
124-
checkpoint_server_addr: checkpoint_server_addr,
124+
checkpoint_metadata: checkpoint_metadata,
125125
shrink_only: shrink_only,
126126
});
127127

@@ -131,38 +131,40 @@ impl ManagerClient {
131131

132132
let response = self.runtime.block_on(self.client.clone().quorum(request))?;
133133
let resp = response.into_inner();
134-
Ok((
135-
resp.quorum_id,
136-
resp.replica_rank,
137-
resp.replica_world_size,
138-
resp.address,
139-
resp.store_address,
140-
resp.max_step,
141-
resp.max_rank,
142-
resp.max_world_size,
143-
resp.heal,
144-
))
134+
Ok(QuorumResult {
135+
quorum_id: resp.quorum_id,
136+
replica_rank: resp.replica_rank,
137+
replica_world_size: resp.replica_world_size,
138+
recover_src_manager_address: resp.recover_src_manager_address,
139+
recover_src_rank: resp.recover_src_rank,
140+
recover_dst_ranks: resp.recover_dst_ranks,
141+
store_address: resp.store_address,
142+
max_step: resp.max_step,
143+
max_rank: resp.max_rank,
144+
max_world_size: resp.max_world_size,
145+
heal: resp.heal,
146+
})
145147
})
146148
}
147149

148-
fn checkpoint_address(
150+
fn checkpoint_metadata(
149151
&self,
150152
py: Python<'_>,
151153
rank: i64,
152154
timeout: Duration,
153155
) -> Result<String, StatusError> {
154156
py.allow_threads(move || {
155-
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
157+
let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank });
156158

157159
// This timeout is processed on the server side so we also enable
158160
// keep alives to detect server health.
159161
request.set_timeout(timeout);
160162

161163
let response = self
162164
.runtime
163-
.block_on(self.client.clone().checkpoint_address(request))?;
165+
.block_on(self.client.clone().checkpoint_metadata(request))?;
164166
let resp = response.into_inner();
165-
Ok(resp.checkpoint_server_address)
167+
Ok(resp.checkpoint_metadata)
166168
})
167169
}
168170

@@ -194,6 +196,41 @@ impl ManagerClient {
194196
}
195197
}
196198

199+
/// Result of a manager quorum request, exposed to Python as a class.
///
/// `get_all, set_all` asks PyO3 to generate a getter and a setter for every
/// field, so each one is a readable and writable Python attribute.
/// `ManagerClient` builds this value by copying the identically named fields
/// of a `ManagerQuorumResponse`.
#[pyclass(get_all, set_all)]
200+
struct QuorumResult {
201+
quorum_id: i64,
202+
// Per the proto comments, `replica_rank`, `replica_world_size` and `heal`
// describe all replicas, including ones that are behind.
replica_rank: i64,
203+
replica_world_size: i64,
204+
// NOTE(review): the `recover_*` fields presumably identify the recovery
// source and destinations used when healing — confirm against the manager.
recover_src_manager_address: String,
205+
// Declared `optional` in the proto, hence `Option`.
recover_src_rank: Option<i64>,
206+
// Declared `repeated` in the proto, hence `Vec`.
recover_dst_ranks: Vec<i64>,
207+
store_address: String,
208+
// Per the proto comments, the `max_*` fields describe the replicas that
// are at the max step.
max_step: i64,
209+
// Declared `optional` in the proto, hence `Option`.
max_rank: Option<i64>,
210+
max_world_size: i64,
211+
heal: bool,
212+
}
213+
214+
/// Python-visible methods of `QuorumResult`.
#[pymethods]
215+
impl QuorumResult {
216+
/// Python constructor (`#[new]`); takes no arguments.
///
/// Returns a placeholder result: `quorum_id`, `replica_rank` and `max_step`
/// are 0, both world sizes are 1, both addresses are empty strings, the
/// optional ranks are `None`, `recover_dst_ranks` is empty and `heal` is
/// false.
#[new]
217+
fn new() -> Self {
218+
Self {
219+
quorum_id: 0,
220+
replica_rank: 0,
221+
replica_world_size: 1,
222+
recover_src_manager_address: "".to_string(),
223+
recover_src_rank: None,
224+
recover_dst_ranks: Vec::new(),
225+
store_address: "".to_string(),
226+
max_step: 0,
227+
max_rank: None,
228+
max_world_size: 1,
229+
heal: false,
230+
}
231+
}
232+
}
233+
197234
fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
198235
// clear python signal handlers
199236
// signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
319356
m.add_class::<Manager>()?;
320357
m.add_class::<ManagerClient>()?;
321358
m.add_class::<Lighthouse>()?;
359+
m.add_class::<QuorumResult>()?;
322360
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
323361

324362
Ok(())

0 commit comments

Comments
 (0)