@@ -189,17 +189,27 @@ def _run_multi_process_test_per_rank(
189
189
self .assertEqual (0 , p .exitcode )
190
190
191
191
192
+ def _wrapper_func_for_multiprocessing (args ): # pyre-ignore[2, 3]
193
+ """Wrapper function that unpacks arguments and calls the original func"""
194
+ func , rank , world_size , kwargs = args
195
+ kwargs ["rank" ] = rank
196
+ kwargs ["world_size" ] = world_size
197
+ return func (** kwargs )
198
+
199
+
200
+ # pyre-ignore[3]
192
201
def run_multi_process_func (
202
+ # pyre-ignore[2]
193
203
func : Callable [
194
204
[int , int , ...], # rank, world_size, ...
195
- None ,
205
+ Any , # Changed from None to Any to allow return values
196
206
],
197
207
multiprocessing_method : str = "spawn" ,
198
208
use_deterministic_algorithms : bool = True ,
199
209
world_size : int = 2 ,
200
210
# pyre-ignore
201
211
** kwargs ,
202
- ) -> None :
212
+ ) -> List [ Any ] :
203
213
""" """
204
214
os .environ ["MASTER_ADDR" ] = str ("localhost" )
205
215
os .environ ["MASTER_PORT" ] = str (get_free_port ())
@@ -215,22 +225,16 @@ def run_multi_process_func(
215
225
if world_size == 1 :
216
226
kwargs ["world_size" ] = 1
217
227
kwargs ["rank" ] = 0
218
- func (** kwargs )
219
- return
228
+ result = func (** kwargs )
229
+ return [result ]
230
+
220
231
ctx = multiprocessing .get_context (multiprocessing_method )
221
- processes = []
222
- for rank in range (world_size ):
223
- kwargs ["rank" ] = rank
224
- kwargs ["world_size" ] = world_size
225
- p = ctx .Process (
226
- target = func ,
227
- name = f"rank{ rank } " ,
228
- kwargs = kwargs ,
229
- )
230
- p .start ()
231
- processes .append (p )
232
232
233
- for p in processes :
234
- p .join ()
235
- if p .exitcode != 0 :
236
- print (p )
233
+ # Prepare arguments for each process
234
+ args_list = [(func , rank , world_size , kwargs .copy ()) for rank in range (world_size )]
235
+
236
+ # Create a pool of worker processes for each rank
237
+ with ctx .Pool (processes = world_size ) as pool :
238
+ results = pool .map (_wrapper_func_for_multiprocessing , args_list )
239
+
240
+ return results
0 commit comments