@@ -27,7 +27,7 @@ pub mod torchftpb {
2727}
2828
2929use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
30- use crate :: torchftpb:: { CheckpointAddressRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
30+ use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
3131use pyo3:: prelude:: * ;
3232
3333#[ pyclass]
@@ -113,15 +113,15 @@ impl ManagerClient {
113113 py : Python < ' _ > ,
114114 rank : i64 ,
115115 step : i64 ,
116- checkpoint_server_addr : String ,
116+ checkpoint_metadata : String ,
117117 shrink_only : bool ,
118118 timeout : Duration ,
119- ) -> Result < ( i64 , i64 , i64 , String , String , i64 , Option < i64 > , i64 , bool ) , StatusError > {
119+ ) -> Result < QuorumResult , StatusError > {
120120 py. allow_threads ( move || {
121121 let mut request = tonic:: Request :: new ( ManagerQuorumRequest {
122122 rank : rank,
123123 step : step,
124- checkpoint_server_addr : checkpoint_server_addr ,
124+ checkpoint_metadata : checkpoint_metadata ,
125125 shrink_only : shrink_only,
126126 } ) ;
127127
@@ -131,38 +131,40 @@ impl ManagerClient {
131131
132132 let response = self . runtime . block_on ( self . client . clone ( ) . quorum ( request) ) ?;
133133 let resp = response. into_inner ( ) ;
134- Ok ( (
135- resp. quorum_id ,
136- resp. replica_rank ,
137- resp. replica_world_size ,
138- resp. address ,
139- resp. store_address ,
140- resp. max_step ,
141- resp. max_rank ,
142- resp. max_world_size ,
143- resp. heal ,
144- ) )
134+ Ok ( QuorumResult {
135+ quorum_id : resp. quorum_id ,
136+ replica_rank : resp. replica_rank ,
137+ replica_world_size : resp. replica_world_size ,
138+ recover_src_manager_address : resp. recover_src_manager_address ,
139+ recover_src_rank : resp. recover_src_rank ,
140+ recover_dst_ranks : resp. recover_dst_ranks ,
141+ store_address : resp. store_address ,
142+ max_step : resp. max_step ,
143+ max_rank : resp. max_rank ,
144+ max_world_size : resp. max_world_size ,
145+ heal : resp. heal ,
146+ } )
145147 } )
146148 }
147149
148- fn checkpoint_address (
150+ fn checkpoint_metadata (
149151 & self ,
150152 py : Python < ' _ > ,
151153 rank : i64 ,
152154 timeout : Duration ,
153155 ) -> Result < String , StatusError > {
154156 py. allow_threads ( move || {
155- let mut request = tonic:: Request :: new ( CheckpointAddressRequest { rank : rank } ) ;
157+ let mut request = tonic:: Request :: new ( CheckpointMetadataRequest { rank : rank } ) ;
156158
157159 // This timeout is processed on the server side so we also enable
158160 // keep alives to detect server health.
159161 request. set_timeout ( timeout) ;
160162
161163 let response = self
162164 . runtime
163- . block_on ( self . client . clone ( ) . checkpoint_address ( request) ) ?;
165+ . block_on ( self . client . clone ( ) . checkpoint_metadata ( request) ) ?;
164166 let resp = response. into_inner ( ) ;
165- Ok ( resp. checkpoint_server_address )
167+ Ok ( resp. checkpoint_metadata )
166168 } )
167169 }
168170
@@ -194,6 +196,41 @@ impl ManagerClient {
194196 }
195197}
196198
199+ #[ pyclass( get_all, set_all) ]
200+ struct QuorumResult {
201+ quorum_id : i64 ,
202+ replica_rank : i64 ,
203+ replica_world_size : i64 ,
204+ recover_src_manager_address : String ,
205+ recover_src_rank : Option < i64 > ,
206+ recover_dst_ranks : Vec < i64 > ,
207+ store_address : String ,
208+ max_step : i64 ,
209+ max_rank : Option < i64 > ,
210+ max_world_size : i64 ,
211+ heal : bool ,
212+ }
213+
214+ #[ pymethods]
215+ impl QuorumResult {
216+ #[ new]
217+ fn new ( ) -> Self {
218+ Self {
219+ quorum_id : 0 ,
220+ replica_rank : 0 ,
221+ replica_world_size : 1 ,
222+ recover_src_manager_address : "" . to_string ( ) ,
223+ recover_src_rank : None ,
224+ recover_dst_ranks : Vec :: new ( ) ,
225+ store_address : "" . to_string ( ) ,
226+ max_step : 0 ,
227+ max_rank : None ,
228+ max_world_size : 1 ,
229+ heal : false ,
230+ }
231+ }
232+ }
233+
197234fn reset_python_signals ( py : Python < ' _ > ) -> PyResult < ( ) > {
198235 // clear python signal handlers
199236 // signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
319356 m. add_class :: < Manager > ( ) ?;
320357 m. add_class :: < ManagerClient > ( ) ?;
321358 m. add_class :: < Lighthouse > ( ) ?;
359+ m. add_class :: < QuorumResult > ( ) ?;
322360 m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
323361
324362 Ok ( ( ) )
0 commit comments