AsyncBuffer update #130

Open · wants to merge 5 commits into base: main
ldp/alg/rollout.py (8 changes: 6 additions & 2 deletions)

@@ -2,6 +2,7 @@
import itertools
import logging
import uuid
from tqdm.asyncio import tqdm_asyncio
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager, nullcontext
from typing import Any, TypeVar, overload
@@ -193,12 +194,15 @@ async def _sample_trajectories_from_envs(
         self.traj_buffer.clear()

         traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
-        await asyncio.gather(
+
+        await tqdm_asyncio.gather(
             *(
                 self._rollout(*args, max_steps=max_steps)
                 for args in zip(traj_ids, environments, strict=True)
-            )
+            ),
+            desc="Sampling trajectories"
         )

         return [self.traj_buffer[traj_id] for traj_id in traj_ids]

async def _rollout(
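For context, tqdm_asyncio.gather is a drop-in replacement for asyncio.gather that renders a progress bar as the gathered awaitables complete, and desc labels that bar. A minimal standalone sketch, not part of the PR (fake_rollout is a made-up placeholder):

import asyncio

from tqdm.asyncio import tqdm_asyncio


async def fake_rollout(i: int) -> int:
    await asyncio.sleep(0.1)  # Stand-in for an environment rollout.
    return i


async def main() -> None:
    # The bar ticks as each coroutine finishes; results keep the input order.
    results = await tqdm_asyncio.gather(
        *(fake_rollout(i) for i in range(8)), desc="Sampling trajectories"
    )
    print(results)


asyncio.run(main())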
ldp/graph/async_torch.py (245 changes: 245 additions & 0 deletions)

@@ -3,6 +3,7 @@
import asyncio
import operator
import time
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
@@ -18,6 +19,9 @@
"ldp.graph.async_torch requires PyTorch as a dependency. "
"Please run `pip install ldp[nn]`."
) from None


logger = logging.getLogger(__name__)

_TORCH_LOCK = asyncio.Lock()

@@ -146,6 +150,247 @@ async def _maybe_process_batch(self):
    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""



class AsyncBufferedWorker2(ABC):
    """Abstract class for a worker that buffers inputs and processes them in batches."""

    def __init__(
        self,
        batch_size: int,
        max_wait_interval: float,
        collate_fn: Callable = lambda x: x,
        decollate_fn: Callable = lambda x: x,
    ):
        """Initialize.

        Args:
            batch_size: The target batch size to use when calling the module. As soon as
                batch_size calls are made, a forward pass is executed.
            max_wait_interval: The maximum time to wait for a batch to fill up before
                executing the calls we have buffered.
            collate_fn: A function to pre-process a list of inputs into a batch. Defaults to a
                no-op.
            decollate_fn: Kind of like the opposite of collate_fn. This function should take
                the batched output and return an ordered list of outputs. Defaults to no-op.
        """
        self.batch_size = batch_size
        self.timeout = max_wait_interval
        self.collate_fn = collate_fn
        self.decollate_fn = decollate_fn

        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
        self._result_buffer: dict[UUID, Any] = {}
        self._lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        self._processed_events = {}
        self._counter = 0
        self._events_count = {}

    async def __call__(self, **kwargs):
        request_id = uuid4()
        request_ts = time.time()

        async with self._lock:
            self._processed_events[request_id] = asyncio.Event()
            self._events_count[request_id] = self._counter
            self._counter += 1
            print(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator comment:
Let's use logging over print
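A sketch of what that could look like, assuming the module-level logger added at the top of this file (the debug level is my assumption):

logger.debug(
    "Started Request ID: %s, Counter: %d", request_id, self._events_count[request_id]
)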

            self._work_buffer.append((request_ts, request_id, kwargs))

            # If we've reached batch size, we trigger the processing event immediately
            if len(self._work_buffer) >= self.batch_size:
                self._batch_ready_event.set()

        try:
            # Wait for either the batch to fill up or the timeout to expire
            await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
        except asyncio.TimeoutError:
            pass
Collaborator comment on lines +206 to +210:
        async with asyncio.timeout(self.timeout):
            # Wait for either the batch to fill up or the timeout to expire
            self._batch_ready_event.wait()

Alternate way of doing this using built-in asyncio.timeout

I am not sure if you need to await the wait()
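For what it is worth, asyncio.Event.wait() is a coroutine, so it does need to be awaited, and asyncio.timeout() (Python 3.11+) raises TimeoutError when the deadline expires, so the timeout still has to be swallowed if expiring should simply fall through. A minimal sketch of the combined pattern, written as a free-standing helper (wait_for_batch_or_timeout is hypothetical):

import asyncio


async def wait_for_batch_or_timeout(event: asyncio.Event, timeout: float) -> None:
    """Block until the batch-ready event is set or `timeout` seconds elapse."""
    try:
        async with asyncio.timeout(timeout):
            await event.wait()
    except TimeoutError:
        pass  # Timing out just means we proceed with whatever is buffered.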


        await self._maybe_process_batch()

        await self._processed_events[request_id].wait()

        async with self._lock:
            print(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
Collaborator comment:
Let's use logging not print

            self._events_count.pop(request_id)
            self._processed_events.pop(request_id)
            return self._result_buffer.pop(request_id)

    async def _maybe_process_batch(self):
        """If the buffer is >= batch size or we have been waiting long enough, process the old batch.

        If neither condition is met, do nothing.
        """
        async with self._lock:
            # If there's at least one request in the buffer, we can process it
            if len(self._work_buffer) == 0:
                return

            self._work_buffer.sort(key=operator.itemgetter(0))

            batch = self._work_buffer[: self.batch_size]
            self._work_buffer = self._work_buffer[self.batch_size :]

            if len(self._work_buffer) < self.batch_size:
                self._batch_ready_event.clear()

        # Construct the batch tensors
        sample_kwargs = [x[2] for x in batch]
        batch_kwargs = self.collate_fn(sample_kwargs)

        print(f"starting to wait for batched call, counter: {self._counter}")
        batched_results = await self._batched_call(batch_kwargs)
        print(f"finished waiting for batched call, counter: {self._counter}")
        request_ids = [x[1] for x in batch]
        results = self.decollate_fn(batched_results)
        async with self._lock:
            print(f"updating result buffer, counter: {self._counter}")
            self._result_buffer.update(zip(request_ids, results, strict=True))
            for request_id in request_ids:
                self._processed_events[request_id].set()

    def _process_batch(self):
        """Processes the current batch."""

    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""


class AsyncBufferedWorker2(ABC):
    def __init__(
        self,
        batch_size: int,
        max_wait_interval: float,
        collate_fn: Callable = lambda x: x,
        decollate_fn: Callable = lambda x: x,
    ):
        self.batch_size = batch_size
        self.timeout = max_wait_interval
        self.collate_fn = collate_fn
        self.decollate_fn = decollate_fn

        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
        self._result_buffer: dict[UUID, Any] = {}
        self._lock = asyncio.Lock()
        self._new_data_event = asyncio.Event()

        self._processed_events: dict[UUID, asyncio.Event] = {}
        self._counter = 0
        self._events_count: dict[UUID, int] = {}  # Just for debugging and printing the order of requests
        self._exception: Exception | None = None  # Store exception from _batch_processor

        # Start the background batch processing task
        self._batch_processing_task = asyncio.create_task(self._batch_processor())
        self._batch_processing_task.add_done_callback(self._handle_task_exception)

    async def __call__(self, **kwargs):
        request_id = uuid4()
        request_ts = time.time()

        async with self._lock:
            if self._exception is not None:
                # If an exception has occurred, raise it immediately
                raise self._exception
            self._processed_events[request_id] = asyncio.Event()
            self._events_count[request_id] = self._counter
            self._counter += 1
            self._work_buffer.append((request_ts, request_id, kwargs))
            if len(self._work_buffer) >= self.batch_size:
                self._new_data_event.set()  # Signal that new data has arrived

        # Wait for the result to be processed or an exception to occur
        await self._processed_events[request_id].wait()

        async with self._lock:
            self._events_count.pop(request_id)
            self._processed_events.pop(request_id)
            if self._exception is not None:
                # If an exception occurred during processing, raise it here
                raise self._exception
            elif request_id in self._result_buffer:
                return self._result_buffer.pop(request_id)
            else:
                # Should not happen, but handle just in case
                raise RuntimeError("Result not available and no exception set.")

    async def _batch_processor(self):
        try:
            while True:
                try:
                    # Wait for new data or timeout
                    await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
                except asyncio.TimeoutError:
                    pass

                async with self._lock:
                    if len(self._work_buffer) == 0:
                        self._new_data_event.clear()
                        continue

                    # Sort the work buffer by timestamp to maintain order
                    self._work_buffer.sort(key=operator.itemgetter(0))

                    batch = self._work_buffer[:self.batch_size]
                    self._work_buffer = self._work_buffer[self.batch_size:]
                    if len(self._work_buffer) == 0:
                        self._new_data_event.clear()

                # Process the batch outside the lock
                sample_kwargs = [x[2] for x in batch]
                batch_kwargs = self.collate_fn(sample_kwargs)
                batched_results = await self._batched_call(batch_kwargs)
                request_ids = [x[1] for x in batch]
                results = self.decollate_fn(batched_results)
                async with self._lock:
                    self._result_buffer.update(zip(request_ids, results))
                    for request_id in request_ids:
                        self._processed_events[request_id].set()

                # Let other requests proceed as soon as their result is available
                await asyncio.sleep(0)
        except asyncio.CancelledError:
            pass  # Allow task to exit gracefully
        except Exception as e:
            logger.error(f"Exception in _batch_processor: {e}", exc_info=True)
            # Store the exception
            async with self._lock:
                self._exception = e
                # Notify all pending requests about the exception
                for event in self._processed_events.values():
                    event.set()

    def _handle_task_exception(self, task):
        try:
            task.result()
        except asyncio.CancelledError:
            # Task was cancelled, nothing to do
            pass
        except Exception as e:
            # Already handled in _batch_processor
            pass

    async def close(self):
        self._batch_processing_task.cancel()
        try:
            await self._batch_processing_task
        except asyncio.CancelledError:
            pass

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    @abstractmethod
    async def _batched_call(self, batch_kwargs: dict[str, Any]):
        """Logic to call the worker on a batch of inputs."""
        pass


class AsyncTorchModule(AsyncBufferedWorker):
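For illustration, a minimal sketch of how the buffered worker added in this PR might be subclassed and exercised, assuming the second AsyncBufferedWorker2 definition above (the one with the background _batch_processor task and close()); EchoWorker and its doubling behavior are hypothetical:

import asyncio


class EchoWorker(AsyncBufferedWorker2):
    """Hypothetical subclass: 'processes' a batch by doubling each input."""

    async def _batched_call(self, batch_kwargs):
        # With the default no-op collate_fn, batch_kwargs is the list of
        # per-request kwargs dicts in buffer order.
        await asyncio.sleep(0.01)  # Stand-in for a real batched forward pass.
        return [kw["x"] * 2 for kw in batch_kwargs]


async def main() -> None:
    worker = EchoWorker(batch_size=4, max_wait_interval=0.1)
    # Concurrent callers are transparently grouped into batches of up to 4;
    # a final, smaller batch is flushed once max_wait_interval elapses.
    results = await asyncio.gather(*(worker(x=i) for i in range(10)))
    print(results)  # [0, 2, 4, ..., 18]
    await worker.close()


asyncio.run(main())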