@@ -57,6 +57,18 @@ static const char kHalDeviceQueueExecute[] =
5757 signal_semaphores: Semaphores/Fence to signal.
5858)" ;
5959
60+ static const char kHalDeviceQueueCopy [] =
61+ R"( Copy data from a source buffer to destination buffer.
62+
63+ Args:
64+ source_buffer: `HalBuffer` that holds src data.
65+ target_buffer: `HalBuffer` that will receive data.
66+ wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
67+ a HalFence. The allocation will be made once these semaphores are
68+ satisfied.
69+ signal_semaphores: Semaphores/Fence to signal.
70+ )" ;
71+
6072static const char kHalFenceWait [] =
6173 R"( Waits until the fence is signalled or errored.
6274
@@ -524,6 +536,69 @@ void HalDevice::QueueExecute(py::handle command_buffers,
524536 " executing command buffers" );
525537}
526538
539+ void HalDevice::QueueCopy (HalBuffer& source_buffer, HalBuffer& target_buffer,
540+ py::handle wait_semaphores,
541+ py::handle signal_semaphores) {
542+ iree_hal_semaphore_list_t wait_list;
543+ iree_hal_semaphore_list_t signal_list;
544+
545+ // Wait list.
546+ if (py::isinstance<HalFence>(wait_semaphores)) {
547+ wait_list = iree_hal_fence_semaphore_list (
548+ py::cast<HalFence*>(wait_semaphores)->raw_ptr ());
549+ } else {
550+ size_t wait_count = py::len (wait_semaphores);
551+ wait_list = {
552+ wait_count,
553+ /* semaphores=*/
554+ static_cast <iree_hal_semaphore_t **>(
555+ alloca (sizeof (iree_hal_semaphore_t *) * wait_count)),
556+ /* payload_values=*/
557+ static_cast <uint64_t *>(alloca (sizeof (uint64_t ) * wait_count)),
558+ };
559+ for (size_t i = 0 ; i < wait_count; ++i) {
560+ py::tuple pair = wait_semaphores[i];
561+ wait_list.semaphores [i] = py::cast<HalSemaphore*>(pair[0 ])->raw_ptr ();
562+ wait_list.payload_values [i] = py::cast<uint64_t >(pair[1 ]);
563+ }
564+ }
565+
566+ // Signal list.
567+ if (py::isinstance<HalFence>(signal_semaphores)) {
568+ signal_list = iree_hal_fence_semaphore_list (
569+ py::cast<HalFence*>(signal_semaphores)->raw_ptr ());
570+ } else {
571+ size_t signal_count = py::len (signal_semaphores);
572+ signal_list = {
573+ signal_count,
574+ /* semaphores=*/
575+ static_cast <iree_hal_semaphore_t **>(
576+ alloca (sizeof (iree_hal_semaphore_t *) * signal_count)),
577+ /* payload_values=*/
578+ static_cast <uint64_t *>(alloca (sizeof (uint64_t ) * signal_count)),
579+ };
580+ for (size_t i = 0 ; i < signal_count; ++i) {
581+ py::tuple pair = signal_semaphores[i];
582+ signal_list.semaphores [i] = py::cast<HalSemaphore*>(pair[0 ])->raw_ptr ();
583+ signal_list.payload_values [i] = py::cast<uint64_t >(pair[1 ]);
584+ }
585+ }
586+
587+ // TODO: Accept params for src_offset and target_offset.
588+ iree_device_size_t source_length =
589+ iree_hal_buffer_byte_length (source_buffer.raw_ptr ());
590+ if (source_length != iree_hal_buffer_byte_length (target_buffer.raw_ptr ())) {
591+ throw std::invalid_argument (
592+ " Source and target buffer length must match and it does not. Please "
593+ " check allocations" );
594+ }
595+ CheckApiStatus (iree_hal_device_queue_copy (
596+ raw_ptr (), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
597+ signal_list, source_buffer.raw_ptr (), 0 ,
598+ target_buffer.raw_ptr (), 0 , source_length),
599+ " Copying buffer on queue" );
600+ }
601+
527602// ------------------------------------------------------------------------------
528603// HalDriver
529604// ------------------------------------------------------------------------------
@@ -861,6 +936,9 @@ void SetupHalBindings(nanobind::module_ m) {
861936 .def (" queue_execute" , &HalDevice::QueueExecute,
862937 py::arg (" command_buffers" ), py::arg (" wait_semaphores" ),
863938 py::arg (" signal_semaphores" ), kHalDeviceQueueExecute )
939+ .def (" queue_copy" , &HalDevice::QueueCopy, py::arg (" source_buffer" ),
940+ py::arg (" target_buffer" ), py::arg (" wait_semaphores" ),
941+ py::arg (" signal_semaphores" ), kHalDeviceQueueCopy )
864942 .def (" __repr__" , [](HalDevice& self) {
865943 auto id_sv = iree_hal_device_id (self.raw_ptr ());
866944 return std::string (id_sv.data , id_sv.size );
@@ -963,6 +1041,9 @@ void SetupHalBindings(nanobind::module_ m) {
9631041 py::class_<HalBuffer>(m, " HalBuffer" )
9641042 .def (" fill_zero" , &HalBuffer::FillZero, py::arg (" byte_offset" ),
9651043 py::arg (" byte_length" ))
1044+ .def (" byte_length" , &HalBuffer::byte_length)
1045+ .def (" memory_type" , &HalBuffer::memory_type)
1046+ .def (" allowed_usage" , &HalBuffer::allowed_usage)
9661047 .def (" create_view" , &HalBuffer::CreateView, py::arg (" shape" ),
9671048 py::arg (" element_size" ), py::keep_alive<0 , 1 >())
9681049 .def (" map" , HalMappedMemory::CreateFromBuffer, py::keep_alive<0 , 1 >())
@@ -994,6 +1075,8 @@ void SetupHalBindings(nanobind::module_ m) {
9941075 py::arg (" buffer" ), py::arg (" shape" ), py::arg (" element_type" ));
9951076 hal_buffer_view
9961077 .def (" map" , HalMappedMemory::CreateFromBufferView, py::keep_alive<0 , 1 >())
1078+ .def (" get_buffer" , HalBuffer::CreateFromBufferView,
1079+ py::keep_alive<0 , 1 >())
9971080 .def_prop_ro (" shape" ,
9981081 [](HalBufferView& self) {
9991082 iree_host_size_t rank =
0 commit comments