import logging
import os
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue, get_context
from multiprocessing.context import BaseContext
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized as BaseValue
from queue import Empty
from typing import Any, Iterable, Optional, Type


# Single item should be processed in less than:
processing_timeout = 10 * 60 # seconds
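# Per-worker factor used to size the inter-process input/output queues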
max_internal_batch_size = 200


class QueueSignals(str, Enum):
stop = "stop"
confirm = "confirm"
error = "error"


class Worker:
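    """
    Interface for the units of work executed by ParallelWorkerPool.

    Subclasses implement `start` as a factory that builds a ready-to-use
    worker inside the spawned process, and `process` as a generator that
    consumes `(index, item)` pairs and yields `(index, result)` pairs.

    A minimal sketch of a subclass (the name and logic here are illustrative
    only, not part of this module):

        class UppercaseWorker(Worker):
            @classmethod
            def start(cls, **kwargs: Any) -> "UppercaseWorker":
                return cls()

            def process(self, items):
                for idx, text in items:
                    yield idx, text.upper()
    """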
@classmethod
def start(cls, *args: Any, **kwargs: Any) -> "Worker":
raise NotImplementedError()

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
raise NotImplementedError()


def _worker(
worker_class: Type[Worker],
input_queue: Queue,
output_queue: Queue,
num_active_workers: BaseValue,
worker_id: int,
kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
A worker that pulls data pints off the input queue, and places the execution result on the output queue.
When there are no data pints left on the input queue, it decrements
num_active_workers to signal completion.
"""
if kwargs is None:
kwargs = {}
logging.info(
f"Reader worker: {worker_id} PID: {os.getpid()} Device: {kwargs.get('device_id', 'CPU')}"
)
try:
worker = worker_class.start(**kwargs)
        # Keep pulling items until the stop signal arrives.
def input_queue_iterable() -> Iterable[Any]:
while True:
item = input_queue.get()
if item == QueueSignals.stop:
break
yield item
for processed_item in worker.process(input_queue_iterable()):
output_queue.put(processed_item)
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
output_queue.put(QueueSignals.error)
finally:
# It's important that we close and join the queue here before
# decrementing num_active_workers. Otherwise our parent may join us
# before the queue's feeder thread has passed all buffered items to
# the underlying pipe resulting in a deadlock.
#
# See:
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
input_queue.close()
output_queue.close()
input_queue.join_thread()
output_queue.join_thread()
with num_active_workers.get_lock():
num_active_workers.value -= 1
logging.info(f"Reader worker {worker_id} finished")


class ParallelWorkerPool:
def __init__(
self,
num_workers: int,
worker: Type[Worker],
start_method: Optional[str] = None,
device_ids: Optional[list[int]] = None,
cuda: bool = False,
):
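        """
        Pool that distributes a stream of items across worker processes.

        @param num_workers: number of worker processes to spawn on start()
        @param worker: Worker subclass; each process builds one via its start() factory
        @param start_method: multiprocessing start method (e.g. "fork" or "spawn")
        @param device_ids: optional device ids, assigned to workers round-robin
        @param cuda: whether to pass cuda=True to each worker
        """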
self.worker_class = worker
self.num_workers = num_workers
self.input_queue: Optional[Queue] = None
self.output_queue: Optional[Queue] = None
self.ctx: BaseContext = get_context(start_method)
self.processes: list[BaseProcess] = []
self.queue_size = self.num_workers * max_internal_batch_size
self.emergency_shutdown = False
self.device_ids = device_ids
self.cuda = cuda
self.num_active_workers: Optional[BaseValue] = None

    def start(self, **kwargs: Any) -> None:
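        """
        Create the input/output queues and spawn the worker processes.
        Extra keyword arguments are forwarded to each worker's start() factory.
        """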
self.input_queue = self.ctx.Queue(self.queue_size)
self.output_queue = self.ctx.Queue(self.queue_size)
ctx_value = self.ctx.Value("i", self.num_workers)
assert isinstance(ctx_value, BaseValue)
self.num_active_workers = ctx_value
for worker_id in range(0, self.num_workers):
worker_kwargs = deepcopy(kwargs)
if self.device_ids:
device_id = self.device_ids[worker_id % len(self.device_ids)]
worker_kwargs["device_id"] = device_id
worker_kwargs["cuda"] = self.cuda
assert hasattr(self.ctx, "Process")
process = self.ctx.Process(
target=_worker,
args=(
self.worker_class,
self.input_queue,
self.output_queue,
self.num_active_workers,
worker_id,
worker_kwargs,
),
)
process.start()
self.processes.append(process)

    def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
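        """
        Map `stream` through the workers and yield results in the original
        input order, buffering any results that arrive out of order.
        """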
        buffer: dict[int, Any] = {}
next_expected = 0
for idx, item in self.semi_ordered_map(stream, *args, **kwargs):
buffer[idx] = item
while next_expected in buffer:
yield buffer.pop(next_expected)
next_expected += 1

    def semi_ordered_map(
self, stream: Iterable[Any], *args: Any, **kwargs: Any
) -> Iterable[tuple[int, Any]]:
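        """
        Map `stream` through the workers and yield `(index, result)` pairs as
        they become available, which may be out of input order.

        Applies backpressure: once `queue_size` items are in flight, it blocks
        on the output queue (up to `processing_timeout` seconds) before
        pushing more input.
        """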
try:
self.start(**kwargs)
assert self.input_queue is not None, "Input queue was not initialized"
assert self.output_queue is not None, "Output queue was not initialized"
pushed = 0
read = 0
for idx, item in enumerate(stream):
self.check_worker_health()
if pushed - read < self.queue_size:
try:
out_item = self.output_queue.get_nowait()
except Empty:
out_item = None
else:
try:
out_item = self.output_queue.get(timeout=processing_timeout)
except Empty as e:
self.join_or_terminate()
raise e
if out_item is not None:
if out_item == QueueSignals.error:
self.join_or_terminate()
                        raise RuntimeError("Worker process terminated unexpectedly")
yield out_item
read += 1
self.input_queue.put((idx, item))
pushed += 1
for _ in range(self.num_workers):
self.input_queue.put(QueueSignals.stop)
while read < pushed:
self.check_worker_health()
out_item = self.output_queue.get(timeout=processing_timeout)
if out_item == QueueSignals.error:
self.join_or_terminate()
                    raise RuntimeError("Worker process terminated unexpectedly")
yield out_item
read += 1
finally:
assert self.input_queue is not None, "Input queue is None"
assert self.output_queue is not None, "Output queue is None"
self.join()
self.input_queue.close()
self.output_queue.close()
if self.emergency_shutdown:
self.input_queue.cancel_join_thread()
self.output_queue.cancel_join_thread()
else:
self.input_queue.join_thread()
self.output_queue.join_thread()

    def check_worker_health(self) -> None:
"""
Checks if any worker process has terminated unexpectedly
"""
for process in self.processes:
if not process.is_alive() and process.exitcode != 0:
self.emergency_shutdown = True
self.join_or_terminate()
raise RuntimeError(
f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}"
)

    def join_or_terminate(self, timeout: Optional[int] = 1) -> None:
"""
Emergency shutdown
@param timeout:
@return:
"""
for process in self.processes:
process.join(timeout=timeout)
if process.is_alive():
process.terminate()
self.processes.clear()

    def join(self) -> None:
for process in self.processes:
process.join()
self.processes.clear()

    def __del__(self) -> None:
"""
Terminate processes if the user hasn't joined. This is necessary as
leaving stray processes running can corrupt shared state. In brief,
we've observed shared memory counters being reused (when the memory was
free from the perspective of the parent process) while the stray
workers still held a reference to them.
For a discussion of using destructors in Python in this manner, see
https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
"""
for process in self.processes:
if process.is_alive():
process.terminate()
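

# A minimal usage sketch (illustrative only). `SquareWorker` is a
# hypothetical Worker that squares integers; it is defined at module level
# so it can be pickled under the "spawn" start method.
class SquareWorker(Worker):
    @classmethod
    def start(cls, **kwargs: Any) -> "SquareWorker":
        return cls()

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        for idx, value in items:
            yield idx, value * value


if __name__ == "__main__":
    pool = ParallelWorkerPool(num_workers=2, worker=SquareWorker)
    # ordered_map preserves input order even though workers run in parallel.
    print(list(pool.ordered_map(range(10))))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]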