Skip to content

Commit 2ab329e

Browse files
authored
ProcessGroupBaby: use pipe for improved performance (#121)
1 parent 082753c commit 2ab329e

File tree

4 files changed

+152
-187
lines changed

4 files changed

+152
-187
lines changed

torchft/multiprocessing.py

+16-79
Original file line number | Diff line number | Diff line change
@@ -1,95 +1,32 @@
11
import queue
22
import time
33
from datetime import timedelta
4+
from multiprocessing.connection import Connection
45
from typing import Union
56

67
import torch.multiprocessing as mp
78

89

9-
class _MonitoredQueue:
10-
def __init__(
11-
self,
12-
p: mp.Process,
13-
q: mp.Queue,
14-
poll_interval: timedelta = timedelta(seconds=1),
15-
) -> None:
16-
"""
17-
Args:
18-
p: process to monitor
19-
q: queue to monitor
20-
poll_interval: interval to poll the Process health when calling get/put
21-
"""
22-
self._p = p
23-
self._q = q
24-
self._poll_interval_s: float = poll_interval.total_seconds()
10+
class _MonitoredPipe:
11+
def __init__(self, pipe: "Connection[object, object]") -> None:
12+
self._pipe = pipe
2513

26-
def get(self, timeout: Union[float, timedelta]) -> object:
27-
"""
28-
Get an item from the queue. If the process is not alive, raise RuntimeError.
29-
If the queue is empty, wait for up to timeout seconds for an item to be
30-
available. If no item is available after timeout seconds, raise TimeoutError.
31-
32-
Args:
33-
timeout: timeout in seconds
34-
"""
14+
def send(self, obj: object) -> None:
15+
self._pipe.send(obj)
3516

17+
def recv(self, timeout: Union[float, timedelta]) -> object:
3618
if isinstance(timeout, timedelta):
3719
timeout = timeout.total_seconds()
38-
39-
start = time.perf_counter()
40-
while True:
41-
try:
42-
v = self._q.get(timeout=self._poll_interval_s)
43-
break
44-
except queue.Empty:
45-
pass
46-
47-
elapsed = time.perf_counter() - start
48-
if elapsed > timeout:
49-
raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
50-
51-
# polling the process can be slow so we only do it every poll_interval
52-
if not self._p.is_alive():
53-
raise RuntimeError(f"process is not alive {self._p.exitcode}")
54-
55-
if isinstance(v, Exception):
56-
raise v
57-
return v
58-
59-
def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
60-
"""
61-
Put an item into the queue. If the process is not alive, raise RuntimeError.
62-
If the queue is full, wait for up to timeout seconds for an item to be
63-
available. If queue is full after timeout seconds, raise TimeoutError.
64-
65-
If an exception is put into the queue, it will be raised when calling get().
66-
67-
Args:
68-
obj: object to put into the queue
69-
timeout: timeout in seconds
70-
"""
71-
if isinstance(timeout, timedelta):
72-
timeout = timeout.total_seconds()
73-
74-
start = time.perf_counter()
75-
while True:
76-
try:
77-
self._q.put(obj, timeout=self._poll_interval_s)
78-
break
79-
except queue.Full:
80-
pass
81-
82-
elapsed = time.perf_counter() - start
83-
if elapsed > timeout:
84-
raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
85-
86-
# polling the process can be slow so we only do it every poll_interval
87-
if not self._p.is_alive():
88-
raise RuntimeError(f"process is not alive {self._p.exitcode}")
20+
if self._pipe.poll(timeout):
21+
out = self._pipe.recv()
22+
if isinstance(out, Exception):
23+
raise out
24+
return out
25+
else:
26+
raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")
8927

9028
def close(self) -> None:
91-
self._q.close()
29+
self._pipe.close()
9230

9331
def closed(self) -> bool:
94-
# pyre-ignore[16]: no attribute _closed
95-
return self._q._closed
32+
return self._pipe.closed

torchft/multiprocessing_test.py

+29-22
Original file line number | Diff line number | Diff line change
@@ -1,48 +1,55 @@
1+
from multiprocessing.connection import Connection
12
from unittest import TestCase
23

34
import torch.multiprocessing as mp
45

5-
from torchft.multiprocessing import _MonitoredQueue
6+
from torchft.multiprocessing import _MonitoredPipe
67

78

8-
def queue_get(q: mp.Queue) -> None:
9-
q.get()
9+
def pipe_get(q: "Connection[object, object]") -> None:
10+
q.recv()
1011

1112

12-
def queue_put(q: mp.Queue) -> None:
13-
q.put(1)
13+
def pipe_put(q: "Connection[object, object]") -> None:
14+
q.recv()
15+
q.send(1)
1416

1517

1618
class MultiprocessingTest(TestCase):
1719
def test_monitored_queue_put(self) -> None:
1820
ctx = mp.get_context("fork")
19-
q = ctx.Queue(maxsize=1)
20-
p = ctx.Process(target=queue_get, args=(q,), daemon=True)
21+
local, remote = ctx.Pipe()
22+
p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
2123
p.start()
24+
del remote
2225

23-
mq = _MonitoredQueue(p, q)
24-
mq.put(1, timeout=10)
25-
mq.put(1, timeout=10)
26-
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
27-
mq.put(1, timeout=10)
28-
29-
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
30-
mq.put(1, timeout=0.0)
26+
mq = _MonitoredPipe(local)
27+
mq.send(1)
28+
with self.assertRaisesRegex(ConnectionResetError, "Connection reset by peer"):
29+
while True:
30+
mq.send(1)
3131

3232
mq.close()
33+
assert mq.closed()
3334

3435
def test_monitored_queue_get(self) -> None:
3536
ctx = mp.get_context("fork")
36-
q = ctx.Queue(maxsize=1)
37-
p = ctx.Process(target=queue_put, args=(q,), daemon=True)
37+
local, remote = ctx.Pipe()
38+
p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
3839
p.start()
40+
del remote
3941

40-
mq = _MonitoredQueue(p, q)
41-
self.assertEqual(mq.get(timeout=10), 1)
42-
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
43-
mq.get(timeout=10)
42+
mq = _MonitoredPipe(local)
4443

4544
with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
46-
mq.get(timeout=0.0)
45+
mq.recv(timeout=0.0)
46+
47+
# unblock the child: pipe_put waits on recv() before it sends its reply
48+
mq.send(1)
49+
50+
self.assertEqual(mq.recv(timeout=10), 1)
51+
with self.assertRaises(EOFError):
52+
mq.recv(timeout=10)
4753

4854
mq.close()
55+
assert mq.closed()

0 commit comments

Comments
 (0)