Skip to content

Commit

Permalink
Put dealloca on timeline
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Feb 19, 2025
1 parent 66f2e05 commit 91b8a3b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
5 changes: 1 addition & 4 deletions shortfin/src/shortfin/array/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,19 @@ storage storage::allocate_device(ScopedDevice &device,
IREE_HAL_ALLOCATOR_POOL_DEFAULT, params, allocation_size,
buffer.for_output()));
SHORTFIN_SCHED_LOG(
"storage::allocate_device(device={}, affinity={:x}):[{}, Wait@{}:"
"storage::allocate_device(device={}, affinity={:x}):[{}, Wait@{}->"
"Signal:@{}] -> buffer={}",
static_cast<void *>(device.raw_device()->hal_device()),
device.affinity().queue_affinity(), static_cast<void *>(timeline_sem),
current_timepoint, signal_timepoint, static_cast<void *>(buffer.get()));

// Device allocations are always async.
// TODO: DO NOT SUBMIT: Enable async destruction.
TimelineResourceDestructor dtor =
TimelineResource::CreateAsyncBufferDestructor(device, buffer);
auto resource = device.fiber().NewTimelineResource(std::move(dtor));
resource->set_mutation_barrier(timeline_sem, signal_timepoint);
resource->use_barrier_insert(timeline_sem, signal_timepoint);
return storage(device, std::move(buffer), std::move(resource));
// return storage(device, std::move(buffer),
// device.fiber().NewTimelineResource());
}

storage storage::allocate_host(ScopedDevice &device,
Expand Down
29 changes: 22 additions & 7 deletions shortfin/src/shortfin/local/scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,40 @@ TimelineResource::~TimelineResource() {

// Builds the destructor callback attached to a TimelineResource that owns a
// device buffer. When the callback runs it enqueues an asynchronous
// deallocation of `buffer` with iree_hal_device_queue_dealloca, using the
// resource's use barrier as the wait list.
//
// Parameters:
//   scoped_device: device whose hal device and affinity are read here.
//   buffer: the buffer to deallocate; moved into the returned callback.
//
// NOTE(review): this span reads like a rendered diff with its +/- markers
// stripped. Two capture lists appear (`device`/`affinity` and
// `device_affinity`), `signal_semaphore_list` is declared twice, and the log
// and dealloca calls each carry two argument lists. Presumably the
// `device`/`affinity` and `iree_hal_semaphore_list_empty()` lines are the
// removed side -- confirm against the real scheduler.cc before relying on it.
TimelineResourceDestructor TimelineResource::CreateAsyncBufferDestructor(
ScopedDevice &scoped_device, iree::hal_buffer_ptr buffer) {
return [device = iree::hal_device_ptr::borrow_reference(
scoped_device.raw_device()->hal_device()),
affinity = scoped_device.affinity().queue_affinity(),
// The ScopedDevice doesn't lifetime extend the underlying hal device, so
// we must do that manually across the callback (and then release at the end).
iree_hal_device_retain(scoped_device.raw_device()->hal_device());
return [device_affinity = scoped_device.affinity(),
buffer = std::move(buffer)](TimelineResource &res) {
// Rebuild a ScopedDevice from the resource's fiber and the captured affinity,
// then read the hal device and queue affinity back out of it.
ScopedDevice scoped_device(*res.fiber(), device_affinity);
iree_hal_device_t *hal_device = scoped_device.raw_device()->hal_device();
auto queue_affinity = scoped_device.affinity().queue_affinity();
SHORTFIN_TRACE_SCOPE_NAMED("TimelineResource::AsyncBufferDestructor");
// Wait list: whatever the resource's use barrier currently holds.
iree_hal_semaphore_list_t wait_semaphore_list = res.use_barrier();
iree_hal_semaphore_list_t signal_semaphore_list =
iree_hal_semaphore_list_empty();
// TODO: If desiring strict memory reclamation ordering, we want to queue
// order the successors to the deallocation.
// Signal list: one entry, the default account's timeline semaphore paired
// with the value returned by account.timeline_acquire_timepoint().
auto fiber = res.fiber();
auto &account = fiber->scheduler().GetDefaultAccount(scoped_device);
iree_hal_semaphore_t *timeline_sem = account.timeline_sem();
uint64_t signal_timepoint = account.timeline_acquire_timepoint();
// The list stores pointers to the two locals above; both stay in scope for
// the dealloca call below.
iree_hal_semaphore_list_t signal_semaphore_list{
.count = 1,
.semaphores = &timeline_sem,
.payload_values = &signal_timepoint,
};
// Only build the debug strings when scheduler logging is enabled.
if (SHORTFIN_SCHED_LOG_ENABLED) {
auto wait_sum = iree::DebugPrintSemaphoreList(wait_semaphore_list);
auto signal_sum = iree::DebugPrintSemaphoreList(signal_semaphore_list);
SHORTFIN_SCHED_LOG(
"async dealloca(device={}, affinity={:x}, buffer={}):[Wait:{}, "
"Signal:{}]",
static_cast<void *>(device.get()), affinity,
static_cast<void *>(hal_device), queue_affinity,
static_cast<void *>(buffer.get()), wait_sum, signal_sum);
}
SHORTFIN_THROW_IF_ERROR(iree_hal_device_queue_dealloca(
device, affinity, wait_semaphore_list, signal_semaphore_list, buffer));
hal_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
buffer));
// Balances the iree_hal_device_retain taken when this callback was created.
// NOTE(review): presumably SHORTFIN_THROW_IF_ERROR throws on failure; if so
// this release is skipped and the device reference leaks. The same applies if
// the callback is never invoked. Confirm, and consider a scope guard.
iree_hal_device_release(hal_device);
};
}

Expand Down
2 changes: 2 additions & 0 deletions shortfin/src/shortfin/local/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class SHORTFIN_API TimelineResource {
if (--refcnt_ == 0) delete this;
}

Fiber *fiber() { return fiber_.get(); }

private:
TimelineResource(std::shared_ptr<Fiber> fiber, size_t semaphore_capacity,
TimelineResourceDestructor destructor);
Expand Down

0 comments on commit 91b8a3b

Please sign in to comment.