diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py
index 06535889..61ebdf0a 100644
--- a/ldp/alg/rollout.py
+++ b/ldp/alg/rollout.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import uuid
+from tqdm.asyncio import tqdm_asyncio
 from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager, nullcontext
 from typing import Any, TypeVar, overload
@@ -193,12 +194,15 @@ async def _sample_trajectories_from_envs(
 
         self.traj_buffer.clear()
         traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
-        await asyncio.gather(
+
+        await tqdm_asyncio.gather(
             *(
                 self._rollout(*args, max_steps=max_steps)
                 for args in zip(traj_ids, environments, strict=True)
-            )
+            ),
+            desc="Sampling trajectories"
         )
+
         return [self.traj_buffer[traj_id] for traj_id in traj_ids]
 
     async def _rollout(
diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py
index 9b74a646..871feb0b 100644
--- a/ldp/graph/async_torch.py
+++ b/ldp/graph/async_torch.py
@@ -3,6 +3,7 @@
 import asyncio
+import logging
 import operator
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import nullcontext
@@ -18,6 +19,9 @@
         "ldp.graph.async_torch requires PyTorch as a dependency. "
         "Please run `pip install ldp[nn]`."
     ) from None
+
+
+logger = logging.getLogger(__name__)
 
 _TORCH_LOCK = asyncio.Lock()
 
@@ -146,6 +150,247 @@ async def _maybe_process_batch(self):
     @abstractmethod
     async def _batched_call(self, batch_kwargs: dict[str, Any]):
         """Logic to call the worker on a batch of inputs."""
+
+
+class AsyncBufferedWorker2(ABC):
+    """Abstract class for a worker that buffers inputs and processes them in batches."""
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_wait_interval: float,
+        collate_fn: Callable = lambda x: x,
+        decollate_fn: Callable = lambda x: x,
+    ):
+        """Initialize.
+
+        Args:
+            batch_size: The target batch size to use when calling the module. As soon as
+                batch_size calls are made, a forward pass is executed.
+            max_wait_interval: The maximum time to wait for a batch to fill up before
+                executing the calls we have buffered.
+            collate_fn: A function to pre-process a list of inputs into a batch. Defaults to
+                a no-op.
+            decollate_fn: The inverse of collate_fn: takes the batched output and returns an
+                ordered list of per-request outputs. Defaults to a no-op.
+        """
+        self.batch_size = batch_size
+        self.timeout = max_wait_interval
+        self.collate_fn = collate_fn
+        self.decollate_fn = decollate_fn
+
+        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
+        self._result_buffer: dict[UUID, Any] = {}
+        self._lock = asyncio.Lock()
+        self._batch_ready_event = asyncio.Event()
+        self._processed_events: dict[UUID, asyncio.Event] = {}
+        self._counter = 0
+        self._events_count: dict[UUID, int] = {}
+
+    async def __call__(self, **kwargs):
+        request_id = uuid4()
+        request_ts = time.time()
+
+        async with self._lock:
+            self._processed_events[request_id] = asyncio.Event()
+            self._events_count[request_id] = self._counter
+            self._counter += 1
+            logger.debug(f"Started Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._work_buffer.append((request_ts, request_id, kwargs))
+
+            # If we've reached batch size, we trigger the processing event immediately
+            if len(self._work_buffer) >= self.batch_size:
+                self._batch_ready_event.set()
+
+        try:
+            # Wait for either the batch to fill up or the timeout to expire
+            await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
+        except asyncio.TimeoutError:
+            pass
+
+        await self._maybe_process_batch()
+
+        await self._processed_events[request_id].wait()
+
+        async with self._lock:
+            logger.debug(f"Finished Request ID: {request_id}, Counter: {self._events_count[request_id]}")
+            self._events_count.pop(request_id)
+            self._processed_events.pop(request_id)
+            return self._result_buffer.pop(request_id)
+
+    async def _maybe_process_batch(self):
+        """If the buffer has reached batch size or we have waited long enough, process the oldest requests.
+
+        If neither condition is met, do nothing.
+        """
+        async with self._lock:
+            # Nothing to process if the buffer is empty
+            if len(self._work_buffer) == 0:
+                return
+
+            self._work_buffer.sort(key=operator.itemgetter(0))
+
+            batch = self._work_buffer[: self.batch_size]
+            self._work_buffer = self._work_buffer[self.batch_size :]
+
+            if len(self._work_buffer) < self.batch_size:
+                self._batch_ready_event.clear()
+
+            # Construct the batch tensors
+            sample_kwargs = [x[2] for x in batch]
+            batch_kwargs = self.collate_fn(sample_kwargs)
+
+        logger.debug(f"Starting batched call, counter: {self._counter}")
+        batched_results = await self._batched_call(batch_kwargs)
+        logger.debug(f"Finished batched call, counter: {self._counter}")
+        request_ids = [x[1] for x in batch]
+        results = self.decollate_fn(batched_results)
+        async with self._lock:
+            logger.debug(f"Updating result buffer, counter: {self._counter}")
+            self._result_buffer.update(zip(request_ids, results, strict=True))
+            for request_id in request_ids:
+                self._processed_events[request_id].set()
+
+    def _process_batch(self):
+        """Processes the current batch."""
+
+    @abstractmethod
+    async def _batched_call(self, batch_kwargs: dict[str, Any]):
+        """Logic to call the worker on a batch of inputs."""
+
+
+class AsyncBufferedWorker2(ABC):
+    """Variant of the buffered worker above that drives batching from a background task."""
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_wait_interval: float,
+        collate_fn: Callable = lambda x: x,
+        decollate_fn: Callable = lambda x: x,
+    ):
+        self.batch_size = batch_size
+        self.timeout = max_wait_interval
+        self.collate_fn = collate_fn
+        self.decollate_fn = decollate_fn
+
+        self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = []
+        self._result_buffer: dict[UUID, Any] = {}
+        self._lock = asyncio.Lock()
+        self._new_data_event = asyncio.Event()
+
+        self._processed_events: dict[UUID, asyncio.Event] = {}
+        self._counter = 0
+        self._events_count: dict[UUID, int] = {}  # Just for debugging and printing the order of requests
+        self._exception: Exception | None = None  # Store exception from _batch_processor
+
+        # Start the background batch processing task
+        self._batch_processing_task = asyncio.create_task(self._batch_processor())
+        self._batch_processing_task.add_done_callback(self._handle_task_exception)
+
+    async def __call__(self, **kwargs):
+        request_id = uuid4()
+        request_ts = time.time()
+
+        async with self._lock:
+            if self._exception is not None:
+                # If an exception has occurred, raise it immediately
+                raise self._exception
+            self._processed_events[request_id] = asyncio.Event()
+            self._events_count[request_id] = self._counter
+            self._counter += 1
+            self._work_buffer.append((request_ts, request_id, kwargs))
+            if len(self._work_buffer) >= self.batch_size:
+                self._new_data_event.set()  # Signal that new data has arrived
+
+        # Wait for the result to be processed or an exception to occur
+        await self._processed_events[request_id].wait()
+
+        async with self._lock:
+            self._events_count.pop(request_id)
+            self._processed_events.pop(request_id)
+            if self._exception is not None:
+                # If an exception occurred during processing, raise it here
+                raise self._exception
+            elif request_id in self._result_buffer:
+                return self._result_buffer.pop(request_id)
+            else:
+                # Should not happen, but handle just in case
+                raise RuntimeError("Result not available and no exception set.")
+
+    async def _batch_processor(self):
+        try:
+            while True:
+                try:
+                    # Wait for new data or timeout
+                    await asyncio.wait_for(self._new_data_event.wait(), timeout=self.timeout)
+                except asyncio.TimeoutError:
+                    pass
+
+                async with self._lock:
+                    if len(self._work_buffer) == 0:
+                        self._new_data_event.clear()
+                        continue
+
+                    # Sort the work buffer by timestamp to maintain order
+                    self._work_buffer.sort(key=operator.itemgetter(0))
+
+                    batch = self._work_buffer[: self.batch_size]
+                    self._work_buffer = self._work_buffer[self.batch_size :]
+                    if len(self._work_buffer) == 0:
+                        self._new_data_event.clear()
+
+                # Process the batch outside the lock
+                sample_kwargs = [x[2] for x in batch]
+                batch_kwargs = self.collate_fn(sample_kwargs)
+                batched_results = await self._batched_call(batch_kwargs)
+                request_ids = [x[1] for x in batch]
+                results = self.decollate_fn(batched_results)
+                async with self._lock:
+                    self._result_buffer.update(zip(request_ids, results, strict=True))
+                    for request_id in request_ids:
+                        self._processed_events[request_id].set()
+
+                # Let other requests proceed as soon as their result is available
+                await asyncio.sleep(0)
+        except asyncio.CancelledError:
+            pass  # Allow task to exit gracefully
+        except Exception as e:
+            logger.error(f"Exception in _batch_processor: {e}", exc_info=True)
+            # Store the exception and notify all pending requests about it
+            async with self._lock:
+                self._exception = e
+                for event in self._processed_events.values():
+                    event.set()
+
+    def _handle_task_exception(self, task):
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            # Task was cancelled, nothing to do
+            pass
+        except Exception:
+            # Already handled and stored in _batch_processor
+            pass
+
+    async def close(self):
+        self._batch_processing_task.cancel()
+        try:
+            await self._batch_processing_task
+        except asyncio.CancelledError:
+            pass
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.close()
+
+    @abstractmethod
+    async def _batched_call(self, batch_kwargs: dict[str, Any]):
+        """Logic to call the worker on a batch of inputs."""
 
 
 class AsyncTorchModule(AsyncBufferedWorker):
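
As a usage sanity check for the new API, here is a minimal sketch of a subclass of the background-task variant of AsyncBufferedWorker2 introduced above. It assumes this patch is applied and the ldp[nn] extra (PyTorch) is installed so that ldp.graph.async_torch imports; note that because both new classes share the name AsyncBufferedWorker2, it is the second (task-based) definition that ends up importable. EchoWorker and its doubling _batched_call are hypothetical and exist only for illustration:

    import asyncio

    from ldp.graph.async_torch import AsyncBufferedWorker2  # background-task variant from this patch


    class EchoWorker(AsyncBufferedWorker2):
        """Hypothetical worker: 'processes' a batch by doubling each value."""

        async def _batched_call(self, batch_kwargs):
            # With the default no-op collate_fn, batch_kwargs is simply the list of
            # per-call kwargs dicts, in arrival order.
            await asyncio.sleep(0.01)  # stand-in for a real batched forward pass
            return [kw["value"] * 2 for kw in batch_kwargs]


    async def main():
        # The background _batch_processor task is created in __init__, so the worker
        # must be constructed (and closed) inside a running event loop.
        async with EchoWorker(batch_size=4, max_wait_interval=0.1) as worker:
            results = await asyncio.gather(*(worker(value=i) for i in range(10)))
            assert results == [2 * i for i in range(10)]


    asyncio.run(main())

Unlike the existing AsyncBufferedWorker, this variant owns a long-lived asyncio task, which is why it also gains close(), __aenter__, and __aexit__ for cleanup.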