Skip to content

Commit 08c69aa

Browse files
[libshortfin] Implement invocation. (#159)
1 parent 89cc4c5 commit 08c69aa

31 files changed

+1505
-412
lines changed

.github/workflows/ci_linux_x64-libshortfin.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ jobs:
8383
- name: Install Python packages
8484
# TODO: Switch to `pip install -r requirements.txt -e libshortfin/`.
8585
run: |
86-
pip install nanobind
8786
pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
87+
pip freeze
8888
8989
- name: Build libshortfin (full)
9090
run: |

.github/workflows/ci_linux_x64_asan-libshortfin.yml

+1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ jobs:
124124
run: |
125125
eval "$(pyenv init -)"
126126
pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
127+
pip freeze
127128
128129
- name: Save Python dependencies cache
129130
if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'

libshortfin/bindings/python/array_binding.cc

+84-42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@ using namespace shortfin::array;
1313
namespace shortfin::python {
1414

1515
namespace {
16+
static const char DOCSTRING_ARRAY_COPY_FROM[] =
17+
R"(Copy contents from a source array to this array.
18+
19+
Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
20+
)";
21+
22+
static const char DOCSTRING_ARRAY_COPY_TO[] =
23+
R"(Copy contents this array to a destination array.
24+
25+
Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
26+
)";
27+
28+
static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value.
29+
30+
Equivalent to `array.storage.fill(pattern)`.
31+
)";
32+
1633
static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents.
1734
1835
Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`.
@@ -28,6 +45,23 @@ As with `map`, this will only work on buffers that are host visible, which
2845
includes all host buffers and device buffers created with the necessary access.
2946
)";
3047

48+
static const char DOCSTRING_STORAGE_COPY_FROM[] =
49+
R"(Copy contents from a source storage to this array.
50+
51+
This operation executes asynchronously and the effect will only be visible
52+
once the execution scope has been synced to the point of mutation.
53+
)";
54+
55+
static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value.
56+
57+
Takes as argument any value that can be interpreted as a buffer with the Python
58+
buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly
59+
with the pattern.
60+
61+
This operation executes asynchronously and the effect will only be visible
62+
once the execution scope has been synced to the point of mutation.
63+
)";
64+
3165
static const char DOCSTRING_STORAGE_MAP[] =
3266
R"(Create a mapping of the buffer contents in host memory.
3367
@@ -72,58 +106,47 @@ void BindArray(py::module_ &m) {
72106
.def(py::self == py::self)
73107
.def("__repr__", &DType::name);
74108

75-
m.attr("opaque8") = DType::opaque8();
76-
m.attr("opaque16") = DType::opaque16();
77-
m.attr("opaque32") = DType::opaque32();
78-
m.attr("opaque64") = DType::opaque64();
79-
m.attr("bool8") = DType::bool8();
80-
m.attr("int4") = DType::int4();
81-
m.attr("sint4") = DType::sint4();
82-
m.attr("uint4") = DType::uint4();
83-
m.attr("int8") = DType::int8();
84-
m.attr("sint8") = DType::sint8();
85-
m.attr("uint8") = DType::uint8();
86-
m.attr("int16") = DType::int16();
87-
m.attr("sint16") = DType::sint16();
88-
m.attr("uint16") = DType::uint16();
89-
m.attr("int32") = DType::int32();
90-
m.attr("sint32") = DType::sint32();
91-
m.attr("uint32") = DType::uint32();
92-
m.attr("int64") = DType::int64();
93-
m.attr("sint64") = DType::sint64();
94-
m.attr("uint64") = DType::uint64();
95-
m.attr("float16") = DType::float16();
96-
m.attr("float32") = DType::float32();
97-
m.attr("float64") = DType::float64();
98-
m.attr("bfloat16") = DType::bfloat16();
99-
m.attr("complex64") = DType::complex64();
100-
m.attr("complex128") = DType::complex128();
109+
#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident();
110+
#include "shortfin/array/dtypes.inl"
111+
#undef SHORTFIN_DTYPE_HANDLE
101112

102113
// storage
103114
py::class_<storage>(m, "storage")
115+
.def("__sfinv_marshal__",
116+
[](device_array *self, py::capsule inv_capsule, int barrier) {
117+
auto *inv =
118+
static_cast<local::ProgramInvocation *>(inv_capsule.data());
119+
static_cast<local::ProgramInvocationMarshalable *>(self)
120+
->AddAsInvocationArgument(
121+
inv, static_cast<local::ProgramResourceBarrier>(barrier));
122+
})
104123
.def_static(
105124
"allocate_host",
106125
[](local::ScopedDevice &device, iree_device_size_t allocation_size) {
107-
return storage::AllocateHost(device, allocation_size);
126+
return storage::allocate_host(device, allocation_size);
108127
},
109128
py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
110129
.def_static(
111130
"allocate_device",
112131
[](local::ScopedDevice &device, iree_device_size_t allocation_size) {
113-
return storage::AllocateDevice(device, allocation_size);
132+
return storage::allocate_device(device, allocation_size);
114133
},
115134
py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
116-
.def("fill",
117-
[](storage &self, py::handle buffer) {
118-
Py_buffer py_view;
119-
int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
120-
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
121-
throw py::python_error();
122-
}
123-
PyBufferReleaser py_view_releaser(py_view);
124-
self.Fill(py_view.buf, py_view.len);
125-
})
126-
.def("copy_from", [](storage &self, storage &src) { self.CopyFrom(src); })
135+
.def(
136+
"fill",
137+
[](storage &self, py::handle buffer) {
138+
Py_buffer py_view;
139+
int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
140+
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
141+
throw py::python_error();
142+
}
143+
PyBufferReleaser py_view_releaser(py_view);
144+
self.fill(py_view.buf, py_view.len);
145+
},
146+
py::arg("pattern"), DOCSTRING_STORAGE_FILL)
147+
.def(
148+
"copy_from", [](storage &self, storage &src) { self.copy_from(src); },
149+
py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM)
127150
.def(
128151
"map",
129152
[](storage &self, bool read, bool write, bool discard) {
@@ -137,7 +160,7 @@ void BindArray(py::module_ &m) {
137160
}
138161
mapping *cpp_mapping = nullptr;
139162
py::object py_mapping = CreateMappingObject(&cpp_mapping);
140-
self.MapExplicit(
163+
self.map_explicit(
141164
*cpp_mapping,
142165
static_cast<iree_hal_memory_access_bits_t>(access));
143166
return py_mapping;
@@ -154,12 +177,12 @@ void BindArray(py::module_ &m) {
154177
[](storage &self) {
155178
mapping *cpp_mapping = nullptr;
156179
py::object py_mapping = CreateMappingObject(&cpp_mapping);
157-
*cpp_mapping = self.MapRead();
180+
*cpp_mapping = self.map_read();
158181
return py_mapping;
159182
},
160183
[](storage &self, py::handle buffer_obj) {
161184
PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE);
162-
auto dest_data = self.MapWriteDiscard();
185+
auto dest_data = self.map_write_discard();
163186
if (src_info.view().len > dest_data.size()) {
164187
throw std::invalid_argument(
165188
fmt::format("Cannot write {} bytes into buffer of {} bytes",
@@ -219,6 +242,14 @@ void BindArray(py::module_ &m) {
219242
py_type, /*keep_alive=*/device.scope(),
220243
device_array::for_device(device, shape, dtype));
221244
})
245+
.def("__sfinv_marshal__",
246+
[](device_array *self, py::capsule inv_capsule, int barrier) {
247+
auto *inv =
248+
static_cast<local::ProgramInvocation *>(inv_capsule.data());
249+
static_cast<local::ProgramInvocationMarshalable *>(self)
250+
->AddAsInvocationArgument(
251+
inv, static_cast<local::ProgramResourceBarrier>(barrier));
252+
})
222253
.def_static("for_device",
223254
[](local::ScopedDevice &device, std::span<const size_t> shape,
224255
DType dtype) {
@@ -243,6 +274,17 @@ void BindArray(py::module_ &m) {
243274
py::rv_policy::reference_internal)
244275
.def_prop_ro("storage", &device_array::storage,
245276
py::rv_policy::reference_internal)
277+
278+
.def(
279+
"fill",
280+
[](py::handle_t<device_array> self, py::handle buffer) {
281+
self.attr("storage").attr("fill")(buffer);
282+
},
283+
py::arg("pattern"), DOCSTRING_ARRAY_FILL)
284+
.def("copy_from", &device_array::copy_from, py::arg("source_array"),
285+
DOCSTRING_ARRAY_COPY_FROM)
286+
.def("copy_to", &device_array::copy_to, py::arg("dest_array"),
287+
DOCSTRING_ARRAY_COPY_TO)
246288
.def("__repr__", &device_array::to_s);
247289
}
248290

0 commit comments

Comments
 (0)