Skip to content

Commit

Permalink
fix cffi life cycle
Browse files Browse the repository at this point in the history
  • Loading branch information
wanglusheng authored and kalcohol committed Feb 9, 2025
1 parent e28b968 commit bcc8b8d
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions axengine/_axclrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(
super().__init__()

self._device_index = 0
self._io = None
self._model_id = None

if provider_options is not None and "device_id" in provider_options[0]:
self._device_index = provider_options[0].get("device_id", 0)
Expand Down Expand Up @@ -214,12 +216,12 @@ def _unload(self):
dev_size = axclrt_cffi.new("uint64_t *")
dev_prt = axclrt_cffi.new("void **")
for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size)
axclrt_lib.axclrtFree(dev_prt[0])
for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
axclrt_lib.axclrtFree(dev_prt[0])
axclrt_lib.axclrtEngineDestroyIO(self._io)
axclrt_lib.axclrtEngineDestroyIO(self._io[0])
self._io = None
if self._model_id[0] is not None and self._model_id[0] != 0:
axclrt_lib.axclrtEngineUnload(self._model_id[0])
Expand Down Expand Up @@ -322,7 +324,7 @@ def _prepare_io(self):
ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
if 0 != ret:
raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.")
return _io[0]
return _io

def run(
self,
Expand Down Expand Up @@ -353,21 +355,21 @@ def run(
if not (npy.flags.c_contiguous or npy.flags.f_contiguous):
npy = np.ascontiguousarray(npy)
npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size)
if 0 != ret:
raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.")
ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
if 0 != ret:
raise RuntimeError(f"axclrtMemcpy failed for input {i}.")

# execute model
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io)
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0])

# get output
outputs = []
if 0 == ret:
for i in range(len(self.get_outputs())):
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
if 0 != ret:
raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
Expand Down

0 comments on commit bcc8b8d

Please sign in to comment.