env (interactive mode): timeout for blocking operations (#440)
Opening the communication named pipes and `read`ing from them are blocking
operations. This change introduces a timeout for both.

High level:
- we associate a thread with `open`, which waits for a signal
  (Event), with a timeout. If it times out, the thread opens
  the associated pipe, which unblocks the main `open`

- writes aren't expected to block
- reads are. To address this, we poll (`select.select`) before reading,
  with a timeout.

The design treats the timeout as resetting for each blocking op.
This is just to simplify the design - rather than tracking elapsed
time across operations.
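
Illustrative sketch (not part of the commit): the snippet below is a minimal, standalone rendering of the two mechanisms described above - a helper thread that force-opens the fifo to unblock a blocked `open`, and a `select.select` poll before reading. The fifo path, helper names, and the 2-second timeout are made up for the example; the real implementation is in compiler_opt/rl/env.py below.

import os
import select
import tempfile
import threading

# POSIX-only example; nothing ever opens the other end of the fifo, so both
# the open and the final check take the timeout path.
TIMEOUT = 2.0
fifo_path = os.path.join(tempfile.mkdtemp(), 'example_pipe')
os.mkfifo(fifo_path)

opened = threading.Event()
timed_out = threading.Event()


def _unblock_open():
  # Wait for the main thread to report a successful open; if that doesn't
  # happen within TIMEOUT, open the write end ourselves so the blocking
  # open() below can return.
  if opened.wait(TIMEOUT):
    return
  timed_out.set()
  with open(fifo_path, 'wb'):
    pass


waiter = threading.Thread(target=_unblock_open)
waiter.start()
# Blocking open of the read end: normally unblocked by the peer opening the
# write end, here by the helper thread after TIMEOUT seconds.
with open(fifo_path, 'rb', buffering=0) as reader:
  if not timed_out.is_set():
    opened.set()
  # Poll before reading so a silent peer can't block us forever.
  ready, _, _ = select.select([reader], [], [], TIMEOUT)
  if not ready:
    print('timed out waiting for data')
  else:
    print('read:', reader.read(64))
waiter.join()
if timed_out.is_set():
  print('peer never opened the pipe')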
mtrofin authored Feb 14, 2025
1 parent 69284f6 commit d04ae3b
Showing 2 changed files with 383 additions and 11 deletions.
149 changes: 139 additions & 10 deletions compiler_opt/rl/env.py
@@ -17,12 +17,15 @@
import dataclasses
from enum import Enum

import logging
import math
import select
import subprocess
import abc
import contextlib
import io
import os
import threading
from collections.abc import Callable, Generator

import numpy as np
@@ -218,6 +221,129 @@ def _reward_fn(a: float, b: float) -> float:
  return {key: _reward_fn(score_a[key], score_b[key]) for key in score_a}


@contextlib.contextmanager
def open_write_pipe(filename: str, *, timeout: float):
"""Open the write pipe or timeout.
Assuming a fifo, the `open` will block until the other party (the process we
communicate to) also opens the pipe. If that doesn't happen, we time out.
Afterwards, `write` ops shouldn't block.
"""
  opened = threading.Event()
  timed_out = threading.Event()

  # start a thread that waits for `open` to unblock. If it doesn't, we open the
  # fifo ourselves just to unblock.
  def _timeout_thread():
    if opened.wait(timeout):
      logging.debug('[timeout thread] writer opened successfully')
      return
    timed_out.set()
    logging.debug('[timeout thread] writer failed to open')
    with open(filename, 'rb'):
      pass

  waiter = threading.Thread(target=_timeout_thread)
  waiter.start()
  try:
    with io.BufferedWriter(io.FileIO(filename, 'wb')) as writer_pipe:
      if not timed_out.is_set():
        opened.set()
      yield writer_pipe
  finally:
    waiter.join()
    if timed_out.is_set():
      # it's possible that the timeout thread timed out but also the other
      # process finally opened the pipe and thus the `writer_pipe` is
      # functional, but at the end we still raise TimeoutError. We accept that
      # right now.
      raise TimeoutError('write pipe open')


@contextlib.contextmanager
def open_read_pipe(filename: str, *, timeout: float):
"""Open the read pipe, with a timeout governing the open and each read.
Just like in the writer case, assuming we're opening a fifo pipe, the open
operation will block until the other party opens the pipe. Then, because this
is a reader, each read operation (and variations - readline, etc) can block,
but no more than the provided timeout.
"""

  # wrap the underlying io.RawIOBase such that we poll before attempting to read
  def _wrap_raw_io(obj: io.RawIOBase):

    def _get_polling_wrapper(wrapped_method):

      def _replacement(*args, **kwargs):
        name = wrapped_method.__name__
        logging.debug('ReaderWithTimeout is asked to %s', name)
        (r, _, _) = select.select([obj], [], [], timeout)
        if r:
          logging.debug('ReaderWithTimeout %s should be unblocked', name)
          result = wrapped_method(*args, **kwargs)
          logging.debug('ReaderWithTimeout %s completed', name)
          return result
        logging.info('ReaderWithTimeout timed out waiting to %s', name)
        raise TimeoutError('timed out reading')

      return _replacement

    obj.read = _get_polling_wrapper(obj.read)
    obj.readline = _get_polling_wrapper(obj.readline)
    obj.readinto = _get_polling_wrapper(obj.readinto)
    obj.readall = _get_polling_wrapper(obj.readall)

    return obj

  opened = threading.Event()
  timed_out = threading.Event()

  # same idea as in the writer case - unblock the `open`
  def _timeout_thread():
    if opened.wait(timeout):
      logging.debug('[timeout thread] reader opened successfully')
      return
    timed_out.set()
    logging.debug('[timeout thread] reader failed to open')
    with open(filename, 'wb'):
      pass
    logging.debug('[timeout thread] force-opened the reader')

  waiter = threading.Thread(target=_timeout_thread)
  waiter.start()
  try:
    # we must wrap the *raw* stream! wrapping the buffered stream would be
    # incorrect because calls to `read` APIs shouldn't poll (they may just
    # return from the buffer).
    with io.BufferedReader(_wrap_raw_io(io.FileIO(filename,
                                                  'rb'))) as reader_pipe:
      if not timed_out.is_set():
        opened.set()
      yield reader_pipe
  finally:
    waiter.join()
    if timed_out.is_set():
      # same as in the writer case - we could successfully keep reading but
      # still report a timeout at the end of this context.
      raise TimeoutError('read pipe open')


@contextlib.contextmanager
def interactive_session(*, reader_name: str, writer_name: str, timeout: float):
"""Start an interactive session with the started process proc.
Blocking pipe operations - open and read - happen under a timeout.
"""

  try:
    with open_write_pipe(writer_name, timeout=timeout) as writer_pipe:
      with open_read_pipe(reader_name, timeout=timeout) as reader_pipe:
        yield (reader_pipe, writer_pipe)
  finally:
    pass


@contextlib.contextmanager
def clang_session(
    clang_path: str,
@@ -268,16 +394,19 @@ def _get_scores() -> dict[str, float]:
         cmdline, stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
       try:
         if interactive:
-          with io.BufferedWriter(io.FileIO(writer_name, 'wb')) as writer_pipe:
-            with io.BufferedReader(io.FileIO(reader_name, 'rb')) as reader_pipe:
-              yield InteractiveClang(
-                  proc,
-                  _get_scores,
-                  module.name,
-                  task_working_dir,
-                  reader_pipe,
-                  writer_pipe,
-              )
+          with interactive_session(
+              writer_name=writer_name,
+              reader_name=reader_name,
+              timeout=compilation_runner.COMPILATION_TIMEOUT.value) as (
+                  reader_pipe, writer_pipe):
+            yield InteractiveClang(
+                proc,
+                _get_scores,
+                module.name,
+                task_working_dir,
+                reader_pipe,
+                writer_pipe,
+            )
         else:
           yield ClangProcess(
               proc,