Skip to content

Commit d04ae3b

Browse files
authored
env (interactive mode): timeout for blocking operations (#440)
The `open`ing of the communication named pipes and the `read` operations are blocking. This introduces a timeout. High level: - we associate a thread with `open`, which waits for a signal (Event), with a timeout. If it times out, the thread opens the associated pipe, which unblocks the main `open` - writes aren't expected to block - reads are. To address, we poll (`select.select`) before reading, with a timeout. The design treats the timeout as resetting for each blocking op. This is just to simplify the design - rather than checking passed time.
1 parent 69284f6 commit d04ae3b

File tree

2 files changed

+383
-11
lines changed

2 files changed

+383
-11
lines changed

compiler_opt/rl/env.py

Lines changed: 139 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
import dataclasses
1818
from enum import Enum
1919

20+
import logging
2021
import math
22+
import select
2123
import subprocess
2224
import abc
2325
import contextlib
2426
import io
2527
import os
28+
import threading
2629
from collections.abc import Callable, Generator
2730

2831
import numpy as np
@@ -218,6 +221,129 @@ def _reward_fn(a: float, b: float) -> float:
218221
return {key: _reward_fn(score_a[key], score_b[key]) for key in score_a}
219222

220223

224+
@contextlib.contextmanager
225+
def open_write_pipe(filename: str, *, timeout: float):
226+
"""Open the write pipe or timeout.
227+
228+
Assuming a fifo, the `open` will block until the other party (the process we
229+
communicate to) also opens the pipe. If that doesn't happen, we time out.
230+
Afterwards, `write` ops shouldn't block.
231+
"""
232+
opened = threading.Event()
233+
timed_out = threading.Event()
234+
235+
# start a thread that waits for `open` to unblock. If it doesn't, we open the
236+
# fifo ourselves just to unblock.
237+
def _timeout_thread():
238+
if opened.wait(timeout):
239+
logging.debug('[timeout thread] writer opened successfully')
240+
return
241+
timed_out.set()
242+
logging.debug('[timeout thread] writer failed to open')
243+
with open(filename, 'rb'):
244+
pass
245+
246+
waiter = threading.Thread(target=_timeout_thread)
247+
waiter.start()
248+
try:
249+
with io.BufferedWriter(io.FileIO(filename, 'wb')) as writer_pipe:
250+
if not timed_out.is_set():
251+
opened.set()
252+
yield writer_pipe
253+
finally:
254+
waiter.join()
255+
if timed_out.is_set():
256+
# it's possible that the timeout thread timed out but also the other
257+
# process finally opened the pipe and thus the `writer_pipe` is
258+
# functional, but at the end we still raise TimeoutError. We accept that
259+
# right now.
260+
raise TimeoutError('write pipe open')
261+
262+
263+
@contextlib.contextmanager
264+
def open_read_pipe(filename: str, *, timeout: float):
265+
"""Open the read pipe, with a timeout governing the open and each read.
266+
267+
Just like in the writer case, assuming we're opening a fifo pipe, the open
268+
operation will block until the other party opens the pipe. Then, because this
269+
is a reader, each read operation (and variations - readline, etc) can block,
270+
but no more than the provided timeout.
271+
"""
272+
273+
# wrap the underlying io.RawIOBase such that we poll before attempting to read
274+
def _wrap_raw_io(obj: io.RawIOBase):
275+
276+
def _get_polling_wrapper(wrapped_method):
277+
278+
def _replacement(*args, **kwargs):
279+
name = wrapped_method.__name__
280+
logging.debug('ReaderWithTimeout is asked to %s', name)
281+
(r, _, _) = select.select([obj], [], [], timeout)
282+
if r:
283+
logging.debug('ReaderWithTimeout %s should be unblocked', name)
284+
result = wrapped_method(*args, **kwargs)
285+
logging.debug('ReaderWithTimeout %s completed', name)
286+
return result
287+
logging.info('ReaderWithTimeout timed out waiting to %s', name)
288+
raise TimeoutError('timed out reading')
289+
290+
return _replacement
291+
292+
obj.read = _get_polling_wrapper(obj.read)
293+
obj.readline = _get_polling_wrapper(obj.readline)
294+
obj.readinto = _get_polling_wrapper(obj.readinto)
295+
obj.readall = _get_polling_wrapper(obj.readall)
296+
297+
return obj
298+
299+
opened = threading.Event()
300+
timed_out = threading.Event()
301+
302+
# same idea as in the writer case - unblock the `open`
303+
def _timeout_thread():
304+
if opened.wait(timeout):
305+
logging.debug('[timeout thread] reader opened successfully')
306+
return
307+
timed_out.set()
308+
logging.debug('[timeout thread] reader failed to open')
309+
with open(filename, 'wb'):
310+
pass
311+
logging.debug('[timeout thread] force-opened the reader')
312+
313+
waiter = threading.Thread(target=_timeout_thread)
314+
waiter.start()
315+
try:
316+
# we must wrap the *raw* stream! wrapping the buffered stream would be
317+
# incorrect because calls to `read` APIs shouldn't poll (they may just
318+
# return from the buffer).
319+
with io.BufferedReader(_wrap_raw_io(io.FileIO(filename,
320+
'rb'))) as reader_pipe:
321+
if not timed_out.is_set():
322+
opened.set()
323+
yield reader_pipe
324+
finally:
325+
waiter.join()
326+
if timed_out.is_set():
327+
# same as in the writer case - we could successfully keep reading but
328+
# still report a timeout at the end of this context.
329+
raise TimeoutError('read pipe open')
330+
331+
332+
@contextlib.contextmanager
333+
def interactive_session(*, reader_name: str, writer_name: str, timeout: float):
334+
"""Start an interactive session with the started process proc.
335+
336+
Blocking pipe operations - open and read - happen under a timeout.
337+
"""
338+
339+
try:
340+
with open_write_pipe(writer_name, timeout=timeout) as writer_pipe:
341+
with open_read_pipe(reader_name, timeout=timeout) as reader_pipe:
342+
yield (reader_pipe, writer_pipe)
343+
finally:
344+
pass
345+
346+
221347
@contextlib.contextmanager
222348
def clang_session(
223349
clang_path: str,
@@ -268,16 +394,19 @@ def _get_scores() -> dict[str, float]:
268394
cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
269395
try:
270396
if interactive:
271-
with io.BufferedWriter(io.FileIO(writer_name, 'wb')) as writer_pipe:
272-
with io.BufferedReader(io.FileIO(reader_name, 'rb')) as reader_pipe:
273-
yield InteractiveClang(
274-
proc,
275-
_get_scores,
276-
module.name,
277-
task_working_dir,
278-
reader_pipe,
279-
writer_pipe,
280-
)
397+
with interactive_session(
398+
writer_name=writer_name,
399+
reader_name=reader_name,
400+
timeout=compilation_runner.COMPILATION_TIMEOUT.value) as (
401+
reader_pipe, writer_pipe):
402+
yield InteractiveClang(
403+
proc,
404+
_get_scores,
405+
module.name,
406+
task_working_dir,
407+
reader_pipe,
408+
writer_pipe,
409+
)
281410
else:
282411
yield ClangProcess(
283412
proc,

0 commit comments

Comments
 (0)