@@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] =
57
57
signal_semaphores: Semaphores/Fence to signal.
58
58
)" ;
59
59
60
// Python-facing docstring attached to the `queue_copy` binding of HalDevice.
static const char kHalDeviceQueueCopy[] =
    R"(Copy data from a source buffer to a destination buffer.

Args:
  source_buffer: `HalBuffer` that holds src data.
  target_buffer: `HalBuffer` that will receive data.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. The copy will be performed once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.
)";
71
+
60
72
static const char kHalFenceWait [] =
61
73
R"( Waits until the fence is signalled or errored.
62
74
@@ -524,6 +536,69 @@ void HalDevice::QueueExecute(py::handle command_buffers,
524
536
" executing command buffers" );
525
537
}
526
538
539
+ void HalDevice::QueueCopy (HalBuffer& source_buffer, HalBuffer& target_buffer,
540
+ py::handle wait_semaphores,
541
+ py::handle signal_semaphores) {
542
+ iree_hal_semaphore_list_t wait_list;
543
+ iree_hal_semaphore_list_t signal_list;
544
+
545
+ // Wait list.
546
+ if (py::isinstance<HalFence>(wait_semaphores)) {
547
+ wait_list = iree_hal_fence_semaphore_list (
548
+ py::cast<HalFence*>(wait_semaphores)->raw_ptr ());
549
+ } else {
550
+ size_t wait_count = py::len (wait_semaphores);
551
+ wait_list = {
552
+ wait_count,
553
+ /* semaphores=*/
554
+ static_cast <iree_hal_semaphore_t **>(
555
+ alloca (sizeof (iree_hal_semaphore_t *) * wait_count)),
556
+ /* payload_values=*/
557
+ static_cast <uint64_t *>(alloca (sizeof (uint64_t ) * wait_count)),
558
+ };
559
+ for (size_t i = 0 ; i < wait_count; ++i) {
560
+ py::tuple pair = wait_semaphores[i];
561
+ wait_list.semaphores [i] = py::cast<HalSemaphore*>(pair[0 ])->raw_ptr ();
562
+ wait_list.payload_values [i] = py::cast<uint64_t >(pair[1 ]);
563
+ }
564
+ }
565
+
566
+ // Signal list.
567
+ if (py::isinstance<HalFence>(signal_semaphores)) {
568
+ signal_list = iree_hal_fence_semaphore_list (
569
+ py::cast<HalFence*>(signal_semaphores)->raw_ptr ());
570
+ } else {
571
+ size_t signal_count = py::len (signal_semaphores);
572
+ signal_list = {
573
+ signal_count,
574
+ /* semaphores=*/
575
+ static_cast <iree_hal_semaphore_t **>(
576
+ alloca (sizeof (iree_hal_semaphore_t *) * signal_count)),
577
+ /* payload_values=*/
578
+ static_cast <uint64_t *>(alloca (sizeof (uint64_t ) * signal_count)),
579
+ };
580
+ for (size_t i = 0 ; i < signal_count; ++i) {
581
+ py::tuple pair = signal_semaphores[i];
582
+ signal_list.semaphores [i] = py::cast<HalSemaphore*>(pair[0 ])->raw_ptr ();
583
+ signal_list.payload_values [i] = py::cast<uint64_t >(pair[1 ]);
584
+ }
585
+ }
586
+
587
+ // TODO: Accept params for src_offset and target_offset.
588
+ iree_device_size_t source_length =
589
+ iree_hal_buffer_byte_length (source_buffer.raw_ptr ());
590
+ if (source_length != iree_hal_buffer_byte_length (target_buffer.raw_ptr ())) {
591
+ throw std::invalid_argument (
592
+ " Source and target buffer length must match and it does not. Please "
593
+ " check allocations" );
594
+ }
595
+ CheckApiStatus (iree_hal_device_queue_copy (
596
+ raw_ptr (), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
597
+ signal_list, source_buffer.raw_ptr (), 0 ,
598
+ target_buffer.raw_ptr (), 0 , source_length),
599
+ " Copying buffer on queue" );
600
+ }
601
+
527
602
// ------------------------------------------------------------------------------
528
603
// HalDriver
529
604
// ------------------------------------------------------------------------------
@@ -861,6 +936,9 @@ void SetupHalBindings(nanobind::module_ m) {
861
936
.def (" queue_execute" , &HalDevice::QueueExecute,
862
937
py::arg (" command_buffers" ), py::arg (" wait_semaphores" ),
863
938
py::arg (" signal_semaphores" ), kHalDeviceQueueExecute )
939
+ .def (" queue_copy" , &HalDevice::QueueCopy, py::arg (" source_buffer" ),
940
+ py::arg (" target_buffer" ), py::arg (" wait_semaphores" ),
941
+ py::arg (" signal_semaphores" ), kHalDeviceQueueCopy )
864
942
.def (" __repr__" , [](HalDevice& self) {
865
943
auto id_sv = iree_hal_device_id (self.raw_ptr ());
866
944
return std::string (id_sv.data , id_sv.size );
@@ -963,6 +1041,9 @@ void SetupHalBindings(nanobind::module_ m) {
963
1041
py::class_<HalBuffer>(m, " HalBuffer" )
964
1042
.def (" fill_zero" , &HalBuffer::FillZero, py::arg (" byte_offset" ),
965
1043
py::arg (" byte_length" ))
1044
+ .def (" byte_length" , &HalBuffer::byte_length)
1045
+ .def (" memory_type" , &HalBuffer::memory_type)
1046
+ .def (" allowed_usage" , &HalBuffer::allowed_usage)
966
1047
.def (" create_view" , &HalBuffer::CreateView, py::arg (" shape" ),
967
1048
py::arg (" element_size" ), py::keep_alive<0 , 1 >())
968
1049
.def (" map" , HalMappedMemory::CreateFromBuffer, py::keep_alive<0 , 1 >())
@@ -994,6 +1075,8 @@ void SetupHalBindings(nanobind::module_ m) {
994
1075
py::arg (" buffer" ), py::arg (" shape" ), py::arg (" element_type" ));
995
1076
hal_buffer_view
996
1077
.def (" map" , HalMappedMemory::CreateFromBufferView, py::keep_alive<0 , 1 >())
1078
+ .def (" get_buffer" , HalBuffer::CreateFromBufferView,
1079
+ py::keep_alive<0 , 1 >())
997
1080
.def_prop_ro (" shape" ,
998
1081
[](HalBufferView& self) {
999
1082
iree_host_size_t rank =
0 commit comments