Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shortfin] Implement async alloc/dealloc of buffers. #507

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 115 additions & 61 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,32 +316,97 @@ local::ProgramInvocation::Future PyFunctionCall(
return local::ProgramInvocation::Invoke(std::move(inv));
}

py::object PyRehydrateRef(local::ProgramInvocation *inv,
iree::vm_opaque_ref ref) {
auto type = ref.get()->type;
// Note that these accessors are dangerous as they assert/abort if
// process-wide registration is not done properly. We assume here that
// since we got a ref out that the basics are set up soundly, but if actually
// doing this on user/dynamic types, we would want to be more defensive.
// TODO: Don't just do a linear scan if we have more than a couple.
// TODO: Find a reliable way to statically cache the type id.
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
array::device_array>() == type) {
// device_array
return py::cast(local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::device_array>(
inv, std::move(ref)));
} else if (local::ProgramInvocationMarshalableFactory::
invocation_marshalable_type<array::storage>() == type) {
// storage
return py::cast(
local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::storage>(inv, std::move(ref)));
// Wraps a ProgramInvocation::Ptr representing a completed (awaited) invocation.
// Holds some additional accounting for marshaling results back to Python.
class PyProgramInvocation {
public:
PyProgramInvocation(local::ProgramInvocation::Ptr inv)
: inv_(std::move(inv)) {}
PyProgramInvocation(const PyProgramInvocation &) = delete;
PyProgramInvocation(PyProgramInvocation &&other)
: inv_(std::move(other.inv_)),
cached_results_(std::move(other.cached_results_)),
results_failure_(other.results_failure_) {}

// Fields that can be bound.
bool assume_no_alias = true;
static std::optional<bool> global_assume_no_alias;

void CheckValid() {
if (!inv_) throw std::invalid_argument("Deallocated invocation");
}
throw std::invalid_argument(
fmt::format("Cannot marshal ref type {} to Python",
to_string_view(iree_vm_ref_type_name(type))));
}
local::ProgramInvocation::Ptr &inv() { return inv_; }

py::object results() {
if (results_failure_) {
throw std::logic_error("Prior attempt to marshal IREE results failed");
}
if (cached_results_) {
return cached_results_;
}

// Cache results.
CheckValid();
results_failure_ = true;

local::CoarseInvocationTimelineImporter::Options options;
options.assume_no_alias = assume_no_alias;
if (global_assume_no_alias) {
options.assume_no_alias = *global_assume_no_alias;
}
local::CoarseInvocationTimelineImporter timeline_importer(inv().get(),
options);
size_t size = inv_->results_size();
py::object tp = py::steal(PyTuple_New(size));
for (size_t i = 0; i < size; ++i) {
iree::vm_opaque_ref ref = inv_->result_ref(i);
if (!ref) {
throw new std::logic_error("Program returned unsupported Python type");
}
py::object item = RehydrateRef(std::move(ref), &timeline_importer);
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
}

cached_results_ = std::move(tp);
results_failure_ = false;
return cached_results_;
}

private:
py::object RehydrateRef(
iree::vm_opaque_ref ref,
local::CoarseInvocationTimelineImporter *timeline_importer) {
auto type = ref.get()->type;
// Note that these accessors are dangerous as they assert/abort if
// process-wide registration is not done properly. We assume here that
// since we got a ref out that the basics are set up soundly, but if
// actually doing this on user/dynamic types, we would want to be more
// defensive.
// TODO: Don't just do a linear scan if we have more than a couple.
// TODO: Find a reliable way to statically cache the type id.
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
array::device_array>() == type) {
// device_array
return py::cast(local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::device_array>(
inv().get(), timeline_importer, std::move(ref)));
} else if (local::ProgramInvocationMarshalableFactory::
invocation_marshalable_type<array::storage>() == type) {
// storage
return py::cast(local::ProgramInvocationMarshalableFactory::
CreateFromInvocationResultRef<array::storage>(
inv().get(), timeline_importer, std::move(ref)));
}
throw std::invalid_argument(
fmt::format("Cannot marshal ref type {} to Python",
to_string_view(iree_vm_ref_type_name(type))));
}

local::ProgramInvocation::Ptr inv_;
py::object cached_results_;
bool results_failure_ = false;
};
std::optional<bool> PyProgramInvocation::global_assume_no_alias;

py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
py::object coro) {
Expand Down Expand Up @@ -743,56 +808,45 @@ void BindLocal(py::module_ &m) {
return local::ProgramModule::ParameterProvider(system, c_params);
},
py::arg("system"), py::arg("params"));
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
py::class_<PyProgramInvocation>(m, "ProgramInvocation")
.def_rw("assume_no_alias", &PyProgramInvocation::assume_no_alias,
"Assumes that no results alias inputs or other buffers")
.def_rw_static(
"global_assume_no_alias",
&PyProgramInvocation::global_assume_no_alias,
"Globally changes the assume_no_alias flag for all invocations")
.def("invoke",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return local::ProgramInvocation::Invoke(std::move(self));
[](PyProgramInvocation &self) {
self.CheckValid();
return local::ProgramInvocation::Invoke(std::move(self.inv()));
})
.def("add_arg",
[](local::ProgramInvocation::Ptr &self, py::handle arg) {
if (!self) throw std::invalid_argument("Deallocated invocation");
py::capsule inv_capsule(self.get());
[](PyProgramInvocation &self, py::handle arg) {
self.CheckValid();
py::capsule inv_capsule(&self.inv());
PyAddProgramInvocationArg(inv_capsule, arg);
})
.def("__iter__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
size_t size = self->results_size();
py::object tp = py::steal(PyTuple_New(size));
for (size_t i = 0; i < size; ++i) {
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
py::object item = PyRehydrateRef(self.get(), std::move(ref));
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
}
return tp.attr("__iter__")();
[](PyProgramInvocation &self) {
return self.results().attr("__iter__")();
})
.def(
"__len__",
[](local::ProgramInvocation::Ptr &self) {
if (!self) throw std::invalid_argument("Deallocated invocation");
return self->results_size();
[](PyProgramInvocation &self) {
self.CheckValid();
return self.inv()->results_size();
},
"The number of results in this invocation")
.def(
"__getitem__",
[](local::ProgramInvocation::Ptr &self, iree_host_size_t i) {
if (!self) throw std::invalid_argument("Deallocated invocation");
iree::vm_opaque_ref ref = self->result_ref(i);
if (!ref) {
throw new std::logic_error(
"Program returned unsupported Python type");
}
return PyRehydrateRef(self.get(), std::move(ref));
[](PyProgramInvocation &self, iree_host_size_t i) {
self.CheckValid();
return self.results().attr("__getitem__")(i);
},
"Gets the i'th result")
.def("__repr__", [](local::ProgramInvocation::Ptr &self) {
if (!self) return std::string("ProgramInvocation(INVALID)");
return self->to_s();
.def("__repr__", [](PyProgramInvocation &self) {
if (!self.inv()) return std::string("ProgramInvocation(INVALID)");
return self.inv()->to_s();
});

py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
Expand Down Expand Up @@ -1207,7 +1261,7 @@ void BindLocal(py::module_ &m) {
// expensive in the C++ API: essentially, ProgramInvocations flow
// through the system precisely one way. As a low level facility, this
// is deemed acceptable.
return py::cast(std::move(result));
return py::cast(PyProgramInvocation(std::move(result)));
});
py::class_<local::MessageFuture, local::Future>(m, "MessageFuture")
.def("result", [](local::MessageFuture &self) {
Expand Down
10 changes: 6 additions & 4 deletions shortfin/src/shortfin/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void device_array::AddAsInvocationArgument(

iree::vm_opaque_ref ref;
*(&ref) = iree_hal_buffer_view_move_ref(buffer_view);
inv->AddArg(std::move(ref));
inv->AddArg(std::move(ref), storage().timeline_resource_.get());

storage().AddInvocationArgBarrier(inv, barrier);
}
Expand All @@ -119,16 +119,18 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() {
}

device_array device_array::CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref) {
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref) {
SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef");
// We don't retain the buffer view in the device array, so just deref it
// vs stealing the ref.
iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get());
iree::hal_buffer_ptr buffer =
iree::hal_buffer_ptr::borrow_reference(iree_hal_buffer_view_buffer(bv));

auto imported_storage =
storage::ImportInvocationResultStorage(inv, std::move(buffer));
auto imported_storage = storage::ImportInvocationResultStorage(
inv, timeline_importer, std::move(buffer));
std::span<const iree_hal_dim_t> shape(iree_hal_buffer_view_shape_dims(bv),
iree_hal_buffer_view_shape_rank(bv));
return device_array(
Expand Down
4 changes: 3 additions & 1 deletion shortfin/src/shortfin/array/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ class SHORTFIN_API device_array
void AddAsInvocationArgument(local::ProgramInvocation *inv,
local::ProgramResourceBarrier barrier) override;
static device_array CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref);
local::ProgramInvocation *inv,
local::CoarseInvocationTimelineImporter *timeline_importer,
iree::vm_opaque_ref ref);
static iree_vm_ref_type_t invocation_marshalable_type();
friend class shortfin::local::ProgramInvocationMarshalableFactory;
};
Expand Down
Loading
Loading