Skip to content

Commit b0bffda

Browse files
committed
[Bindings] Make copies to local host when map is unavailable.
1 parent 59d0fad commit b0bffda

File tree

5 files changed

+191
-4
lines changed

5 files changed

+191
-4
lines changed

runtime/bindings/python/hal.cc

+83
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] =
5757
signal_semaphores: Semaphores/Fence to signal.
5858
)";
5959

60+
// Docstring attached to the `queue_copy` Python method of HalDevice.
static const char kHalDeviceQueueCopy[] =
    R"(Copy data from a source buffer to destination buffer.

The whole source buffer is copied; both buffers must have the same byte
length, otherwise a ValueError is raised.

Args:
  source_buffer: `HalBuffer` that holds src data.
  target_buffer: `HalBuffer` that will receive data.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. The copy will be performed once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.
)";
71+
6072
static const char kHalFenceWait[] =
6173
R"(Waits until the fence is signalled or errored.
6274
@@ -524,6 +536,69 @@ void HalDevice::QueueExecute(py::handle command_buffers,
524536
"executing command buffers");
525537
}
526538

539+
// Enqueues a copy of all of `source_buffer` into `target_buffer`.
//
// `wait_semaphores` and `signal_semaphores` are each either a `HalFence` or a
// Python sequence of `(HalSemaphore, payload_value)` pairs. Both are converted
// to an `iree_hal_semaphore_list_t` and passed to
// `iree_hal_device_queue_copy` with IREE_HAL_QUEUE_AFFINITY_ANY and zero
// offsets on both buffers.
//
// Throws std::invalid_argument when the two buffers differ in byte length.
// The status returned by the queue call is handed to CheckApiStatus.
void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
                          py::handle wait_semaphores,
                          py::handle signal_semaphores) {
  iree_hal_semaphore_list_t wait_list;
  iree_hal_semaphore_list_t signal_list;

  // Build the wait list from either a sequence of pairs or a fence.
  if (!py::isinstance<HalFence>(wait_semaphores)) {
    size_t num_waits = py::len(wait_semaphores);
    // alloca() scratch storage is only released when this function returns,
    // i.e. after the queue call below has consumed the list.
    auto** wait_handles = static_cast<iree_hal_semaphore_t**>(
        alloca(sizeof(iree_hal_semaphore_t*) * num_waits));
    auto* wait_values =
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * num_waits));
    wait_list = {num_waits, wait_handles, wait_values};
    for (size_t idx = 0; idx != num_waits; ++idx) {
      py::tuple entry = wait_semaphores[idx];
      wait_list.semaphores[idx] = py::cast<HalSemaphore*>(entry[0])->raw_ptr();
      wait_list.payload_values[idx] = py::cast<uint64_t>(entry[1]);
    }
  } else {
    HalFence* wait_fence = py::cast<HalFence*>(wait_semaphores);
    wait_list = iree_hal_fence_semaphore_list(wait_fence->raw_ptr());
  }

  // Build the signal list the same way.
  if (!py::isinstance<HalFence>(signal_semaphores)) {
    size_t num_signals = py::len(signal_semaphores);
    auto** signal_handles = static_cast<iree_hal_semaphore_t**>(
        alloca(sizeof(iree_hal_semaphore_t*) * num_signals));
    auto* signal_values =
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * num_signals));
    signal_list = {num_signals, signal_handles, signal_values};
    for (size_t idx = 0; idx != num_signals; ++idx) {
      py::tuple entry = signal_semaphores[idx];
      signal_list.semaphores[idx] =
          py::cast<HalSemaphore*>(entry[0])->raw_ptr();
      signal_list.payload_values[idx] = py::cast<uint64_t>(entry[1]);
    }
  } else {
    HalFence* signal_fence = py::cast<HalFence*>(signal_semaphores);
    signal_list = iree_hal_fence_semaphore_list(signal_fence->raw_ptr());
  }

  // TODO: Accept params for src_offset and target_offset.
  // Until offsets are supported, only equally sized buffers can be copied.
  iree_device_size_t source_length =
      iree_hal_buffer_byte_length(source_buffer.raw_ptr());
  iree_device_size_t target_length =
      iree_hal_buffer_byte_length(target_buffer.raw_ptr());
  if (source_length != target_length) {
    throw std::invalid_argument(
        "Source and target buffer length must match and it does not. Please "
        "check allocations");
  }

  CheckApiStatus(
      iree_hal_device_queue_copy(raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY,
                                 wait_list, signal_list,
                                 source_buffer.raw_ptr(), /*source_offset=*/0,
                                 target_buffer.raw_ptr(), /*target_offset=*/0,
                                 source_length),
      "Copying buffer on queue");
}
601+
527602
//------------------------------------------------------------------------------
528603
// HalDriver
529604
//------------------------------------------------------------------------------
@@ -861,6 +936,9 @@ void SetupHalBindings(nanobind::module_ m) {
861936
.def("queue_execute", &HalDevice::QueueExecute,
862937
py::arg("command_buffers"), py::arg("wait_semaphores"),
863938
py::arg("signal_semaphores"), kHalDeviceQueueExecute)
939+
.def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"),
940+
py::arg("target_buffer"), py::arg("wait_semaphores"),
941+
py::arg("signal_semaphores"), kHalDeviceQueueCopy)
864942
.def("__repr__", [](HalDevice& self) {
865943
auto id_sv = iree_hal_device_id(self.raw_ptr());
866944
return std::string(id_sv.data, id_sv.size);
@@ -963,6 +1041,9 @@ void SetupHalBindings(nanobind::module_ m) {
9631041
py::class_<HalBuffer>(m, "HalBuffer")
9641042
.def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
9651043
py::arg("byte_length"))
1044+
.def("byte_length", &HalBuffer::byte_length)
1045+
.def("memory_type", &HalBuffer::memory_type)
1046+
.def("allowed_usage", &HalBuffer::allowed_usage)
9661047
.def("create_view", &HalBuffer::CreateView, py::arg("shape"),
9671048
py::arg("element_size"), py::keep_alive<0, 1>())
9681049
.def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>())
@@ -994,6 +1075,8 @@ void SetupHalBindings(nanobind::module_ m) {
9941075
py::arg("buffer"), py::arg("shape"), py::arg("element_type"));
9951076
hal_buffer_view
9961077
.def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>())
1078+
.def("get_buffer", HalBuffer::CreateFromBufferView,
1079+
py::keep_alive<0, 1>())
9971080
.def_prop_ro("shape",
9981081
[](HalBufferView& self) {
9991082
iree_host_size_t rank =

runtime/bindings/python/hal.h

+11
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
128128
py::handle signal_semaphores);
129129
void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
130130
py::handle signal_semaphores);
131+
void QueueCopy(HalBuffer& src_buffer, HalBuffer& dst_buffer,
132+
py::handle wait_semaphores, py::handle signal_semaphores);
131133
};
132134

133135
class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
@@ -176,6 +178,10 @@ class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
176178
return iree_hal_buffer_byte_length(raw_ptr());
177179
}
178180

181+
// Returns the result of iree_hal_buffer_memory_type() for this buffer,
// converted to int.
int memory_type() const {
  return static_cast<int>(iree_hal_buffer_memory_type(raw_ptr()));
}
182+
183+
// Returns the result of iree_hal_buffer_allowed_usage() for this buffer,
// converted to int.
int allowed_usage() const {
  return static_cast<int>(iree_hal_buffer_allowed_usage(raw_ptr()));
}
184+
179185
void FillZero(iree_device_size_t byte_offset,
180186
iree_device_size_t byte_length) {
181187
CheckApiStatus(
@@ -197,6 +203,11 @@ class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
197203
return HalBufferView::StealFromRawPtr(bv);
198204
}
199205

206+
// Wraps the buffer that iree_hal_buffer_view_buffer() reports for `bv` in a
// HalBuffer obtained through BorrowFromRawPtr.
static HalBuffer CreateFromBufferView(HalBufferView& bv) {
  auto* view_buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
  return HalBuffer::BorrowFromRawPtr(view_buffer);
}
210+
200211
py::str Repr();
201212
};
202213

runtime/bindings/python/iree/runtime/array_interop.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
HalElementType,
1818
MappedMemory,
1919
MemoryType,
20+
HalFence,
2021
)
2122

2223
__all__ = [
@@ -106,6 +107,20 @@ def to_host(self) -> np.ndarray:
106107
self._transfer_to_host(False)
107108
return self._host_array
108109

110+
def _is_mappable(self) -> bool:
    """Reports whether the backing buffer can be mapped into host memory.

    The buffer counts as mappable only when its memory type has every
    `MemoryType.HOST_VISIBLE` bit set and its allowed usage has every
    `BufferUsage.MAPPING_SCOPED` bit set.

    Returns:
      True if both bit checks pass, False otherwise.
    """
    buffer = self._buffer_view.get_buffer()
    # Compare plain ints on both sides so the outcome does not depend on
    # whether the enum members compare equal to integers.
    host_visible = int(MemoryType.HOST_VISIBLE)
    if (buffer.memory_type() & host_visible) != host_visible:
        return False
    mapping_scoped = int(BufferUsage.MAPPING_SCOPED)
    return (buffer.allowed_usage() & mapping_scoped) == mapping_scoped
123+
109124
def _transfer_to_host(self, implicit):
110125
if self._host_array is not None:
111126
return
@@ -114,7 +129,10 @@ def _transfer_to_host(self, implicit):
114129
"DeviceArray cannot be implicitly transferred to the host: "
115130
"if necessary, do an explicit transfer via .to_host()"
116131
)
117-
self._mapped_memory, self._host_array = self._map_to_host()
132+
if self._is_mappable():
133+
self._mapped_memory, self._host_array = self._map_to_host()
134+
else:
135+
self._host_array = self._copy_to_host()
118136

119137
def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
120138
# TODO: When synchronization is enabled, need to block here.
@@ -129,6 +147,35 @@ def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
129147
host_array = host_array.astype(self._override_dtype)
130148
return mapped_memory, host_array
131149

150+
def _copy_to_host(self) -> np.ndarray:
    """Copies the device buffer into a new host buffer and returns it as an array.

    A HOST_LOCAL | DEVICE_VISIBLE buffer of the same byte length is
    allocated, filled with `queue_copy`, waited on, then mapped and viewed
    with the buffer view's shape and raw dtype.

    Returns:
      The host array, converted to `self._override_dtype` when that is set
      and differs from the raw dtype.
    """
    # TODO: When synchronization is enabled, need to block here.
    device_buffer = self._buffer_view.get_buffer()
    staging_type = MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE
    staging_usage = BufferUsage.TRANSFER_TARGET | BufferUsage.MAPPING_SCOPED
    staging_buffer = self._device.allocator.allocate_buffer(
        memory_type=staging_type,
        allowed_usage=staging_usage,
        allocation_size=device_buffer.byte_length(),
    )

    # Enqueue the copy, then block until the semaphore reaches 1.
    semaphore = self._device.create_semaphore(0)
    copy_ready = HalFence.create_at(semaphore, 0)
    copy_done = HalFence.create_at(semaphore, 1)
    self._device.queue_copy(
        device_buffer,
        staging_buffer,
        wait_semaphores=copy_ready,
        signal_semaphores=copy_done,
    )
    HalFence.create_at(semaphore, 1).wait()

    # View the staging buffer as an np.array of the raw element dtype.
    # NOTE(review): the mapping is only held by this local; confirm that the
    # returned array keeps the mapped memory alive.
    element_dtype = self._get_raw_dtype()
    staging_mapping = staging_buffer.map()
    array = staging_mapping.asarray(self._buffer_view.shape, element_dtype)

    # An override dtype asks us to present the data as a dtype that may not
    # be representable on the device (this is how bools are supported), so
    # an explicit conversion is needed whenever it differs from the raw one.
    needs_cast = (
        self._override_dtype is not None and self._override_dtype != element_dtype
    )
    if needs_cast:
        array = array.astype(self._override_dtype)
    return array
178+
132179
def _get_raw_dtype(self):
    """Returns `HalElementType.map_to_dtype` applied to the buffer view's element type."""
    element_type = self._buffer_view.element_type
    return HalElementType.map_to_dtype(element_type)
134181

runtime/bindings/python/tests/hal_test.py

+46
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,52 @@ def testFenceExtend(self):
265265
fence.extend(iree.runtime.HalFence.create_at(sem2, 2))
266266
self.assertEqual(fence.timepoint_count, 2)
267267

268+
def testRoundTripQueueCopy(self):
    """queue_copy between two equally sized buffers reproduces the source data."""
    rt = iree.runtime
    expected = np.full([3, 4], 2, dtype=np.int32)

    # Source: a device-local buffer initialised from `expected`.
    source_view = self.allocator.allocate_buffer_copy(
        memory_type=rt.MemoryType.DEVICE_LOCAL,
        allowed_usage=rt.BufferUsage.DEFAULT,
        device=self.device,
        buffer=expected,
        element_type=rt.HalElementType.SINT_32,
    )
    src = source_view.get_buffer()

    # Target: an uninitialised buffer of the same byte length.
    dst = self.allocator.allocate_buffer(
        memory_type=rt.MemoryType.DEVICE_LOCAL,
        allowed_usage=rt.BufferUsage.DEFAULT,
        allocation_size=src.byte_length(),
    )

    # Copy, then wait for the semaphore to be signalled to 1.
    semaphore = self.device.create_semaphore(0)
    self.device.queue_copy(
        src,
        dst,
        wait_semaphores=rt.HalFence.create_at(semaphore, 0),
        signal_semaphores=rt.HalFence.create_at(semaphore, 1),
    )
    rt.HalFence.create_at(semaphore, 1).wait()

    actual = dst.map().asarray(expected.shape, expected.dtype)
    np.testing.assert_array_equal(expected, actual)
293+
294+
def testDifferentSizeQueueCopy(self):
    """queue_copy raises ValueError when the buffers differ in byte length."""
    rt = iree.runtime

    # Two device-local buffers whose sizes differ by one byte.
    source_buffer, target_buffer = (
        self.allocator.allocate_buffer(
            memory_type=rt.MemoryType.DEVICE_LOCAL,
            allowed_usage=rt.BufferUsage.DEFAULT,
            allocation_size=size,
        )
        for size in (12, 13)
    )

    semaphore = self.device.create_semaphore(0)
    with self.assertRaisesRegex(ValueError, "length must match"):
        self.device.queue_copy(
            source_buffer,
            target_buffer,
            wait_semaphores=rt.HalFence.create_at(semaphore, 0),
            signal_semaphores=rt.HalFence.create_at(semaphore, 1),
        )
313+
268314
def testCommandBufferStartsByDefault(self):
269315
cb = iree.runtime.HalCommandBuffer(self.device)
270316
with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):

runtime/src/iree/hal/drivers/cuda/cuda_allocator.c

+3-3
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ iree_hal_cuda_allocator_query_buffer_compatibility(
236236
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
237237
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE;
238238
}
239+
if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
240+
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
241+
}
239242

240243
// Buffers can only be used on the queue if they are device visible.
241244
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
242-
if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
243-
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
244-
}
245245
if (iree_any_bit_set(params->usage,
246246
IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
247247
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;

0 commit comments

Comments
 (0)