@@ -27,7 +27,7 @@ pub mod torchftpb {
27
27
}
28
28
29
29
use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
30
- use crate :: torchftpb:: { CheckpointAddressRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
30
+ use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
31
31
use pyo3:: prelude:: * ;
32
32
33
33
#[ pyclass]
@@ -113,15 +113,15 @@ impl ManagerClient {
113
113
py : Python < ' _ > ,
114
114
rank : i64 ,
115
115
step : i64 ,
116
- checkpoint_server_addr : String ,
116
+ checkpoint_metadata : String ,
117
117
shrink_only : bool ,
118
118
timeout : Duration ,
119
- ) -> Result < ( i64 , i64 , i64 , String , String , i64 , Option < i64 > , i64 , bool ) , StatusError > {
119
+ ) -> Result < QuorumResult , StatusError > {
120
120
py. allow_threads ( move || {
121
121
let mut request = tonic:: Request :: new ( ManagerQuorumRequest {
122
122
rank : rank,
123
123
step : step,
124
- checkpoint_server_addr : checkpoint_server_addr ,
124
+ checkpoint_metadata : checkpoint_metadata ,
125
125
shrink_only : shrink_only,
126
126
} ) ;
127
127
@@ -131,38 +131,40 @@ impl ManagerClient {
131
131
132
132
let response = self . runtime . block_on ( self . client . clone ( ) . quorum ( request) ) ?;
133
133
let resp = response. into_inner ( ) ;
134
- Ok ( (
135
- resp. quorum_id ,
136
- resp. replica_rank ,
137
- resp. replica_world_size ,
138
- resp. address ,
139
- resp. store_address ,
140
- resp. max_step ,
141
- resp. max_rank ,
142
- resp. max_world_size ,
143
- resp. heal ,
144
- ) )
134
+ Ok ( QuorumResult {
135
+ quorum_id : resp. quorum_id ,
136
+ replica_rank : resp. replica_rank ,
137
+ replica_world_size : resp. replica_world_size ,
138
+ recover_src_manager_address : resp. recover_src_manager_address ,
139
+ recover_src_rank : resp. recover_src_rank ,
140
+ recover_dst_ranks : resp. recover_dst_ranks ,
141
+ store_address : resp. store_address ,
142
+ max_step : resp. max_step ,
143
+ max_rank : resp. max_rank ,
144
+ max_world_size : resp. max_world_size ,
145
+ heal : resp. heal ,
146
+ } )
145
147
} )
146
148
}
147
149
148
- fn checkpoint_address (
150
+ fn checkpoint_metadata (
149
151
& self ,
150
152
py : Python < ' _ > ,
151
153
rank : i64 ,
152
154
timeout : Duration ,
153
155
) -> Result < String , StatusError > {
154
156
py. allow_threads ( move || {
155
- let mut request = tonic:: Request :: new ( CheckpointAddressRequest { rank : rank } ) ;
157
+ let mut request = tonic:: Request :: new ( CheckpointMetadataRequest { rank : rank } ) ;
156
158
157
159
// This timeout is processed on the server side so we also enable
158
160
// keep alives to detect server health.
159
161
request. set_timeout ( timeout) ;
160
162
161
163
let response = self
162
164
. runtime
163
- . block_on ( self . client . clone ( ) . checkpoint_address ( request) ) ?;
165
+ . block_on ( self . client . clone ( ) . checkpoint_metadata ( request) ) ?;
164
166
let resp = response. into_inner ( ) ;
165
- Ok ( resp. checkpoint_server_address )
167
+ Ok ( resp. checkpoint_metadata )
166
168
} )
167
169
}
168
170
@@ -194,6 +196,41 @@ impl ManagerClient {
194
196
}
195
197
}
196
198
199
+ #[ pyclass( get_all, set_all) ]
200
+ struct QuorumResult {
201
+ quorum_id : i64 ,
202
+ replica_rank : i64 ,
203
+ replica_world_size : i64 ,
204
+ recover_src_manager_address : String ,
205
+ recover_src_rank : Option < i64 > ,
206
+ recover_dst_ranks : Vec < i64 > ,
207
+ store_address : String ,
208
+ max_step : i64 ,
209
+ max_rank : Option < i64 > ,
210
+ max_world_size : i64 ,
211
+ heal : bool ,
212
+ }
213
+
214
+ #[ pymethods]
215
+ impl QuorumResult {
216
+ #[ new]
217
+ fn new ( ) -> Self {
218
+ Self {
219
+ quorum_id : 0 ,
220
+ replica_rank : 0 ,
221
+ replica_world_size : 1 ,
222
+ recover_src_manager_address : "" . to_string ( ) ,
223
+ recover_src_rank : None ,
224
+ recover_dst_ranks : Vec :: new ( ) ,
225
+ store_address : "" . to_string ( ) ,
226
+ max_step : 0 ,
227
+ max_rank : None ,
228
+ max_world_size : 1 ,
229
+ heal : false ,
230
+ }
231
+ }
232
+ }
233
+
197
234
fn reset_python_signals ( py : Python < ' _ > ) -> PyResult < ( ) > {
198
235
// clear python signal handlers
199
236
// signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
319
356
m. add_class :: < Manager > ( ) ?;
320
357
m. add_class :: < ManagerClient > ( ) ?;
321
358
m. add_class :: < Lighthouse > ( ) ?;
359
+ m. add_class :: < QuorumResult > ( ) ?;
322
360
m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
323
361
324
362
Ok ( ( ) )
0 commit comments