From bcc8b8dc8ced7e680605c0742953f281eebe8285 Mon Sep 17 00:00:00 2001 From: wanglusheng Date: Sun, 9 Feb 2025 14:41:13 +0800 Subject: [PATCH] fix cffi life cycle --- axengine/_axclrt.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/axengine/_axclrt.py b/axengine/_axclrt.py index 980fa14..559e716 100644 --- a/axengine/_axclrt.py +++ b/axengine/_axclrt.py @@ -96,6 +96,8 @@ def __init__( super().__init__() self._device_index = 0 + self._io = None + self._model_id = None if provider_options is not None and "device_id" in provider_options[0]: self._device_index = provider_options[0].get("device_id", 0) @@ -214,12 +216,12 @@ def _unload(self): dev_size = axclrt_cffi.new("uint64_t *") dev_prt = axclrt_cffi.new("void **") for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])): - axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size) + axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) axclrt_lib.axclrtFree(dev_prt[0]) for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])): - axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size) + axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) axclrt_lib.axclrtFree(dev_prt[0]) - axclrt_lib.axclrtEngineDestroyIO(self._io) + axclrt_lib.axclrtEngineDestroyIO(self._io[0]) self._io = None if self._model_id[0] is not None and self._model_id[0] != 0: axclrt_lib.axclrtEngineUnload(self._model_id[0]) @@ -322,7 +324,7 @@ def _prepare_io(self): ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size) if 0 != ret: raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.") - return _io[0] + return _io def run( self, @@ -353,7 +355,7 @@ def run( if not (npy.flags.c_contiguous or npy.flags.f_contiguous): npy = np.ascontiguousarray(npy) npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) - ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size) + ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) if 0 != ret: raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.") ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE) @@ -361,13 +363,13 @@ def run( raise RuntimeError(f"axclrtMemcpy failed for input {i}.") # execute model - ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io) + ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0]) # get output outputs = [] if 0 == ret: for i in range(len(self.get_outputs())): - ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size) + ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) if 0 != ret: raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)