diff --git a/axengine/_axclrt.py b/axengine/_axclrt.py index 731186f..e07fa59 100644 --- a/axengine/_axclrt.py +++ b/axengine/_axclrt.py @@ -377,13 +377,13 @@ def run( ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) if 0 != ret: raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") + buffer_addr = dev_prt[0] npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) - npy = np.frombuffer( - axclrt_cffi.buffer( - self._io[0].pOutputs[i].pVirAddr, npy_size - ), - dtype=self.get_outputs(shape_group)[i].dtype, - ).reshape(self.get_outputs(shape_group)[i].shape) + npy = np.zeros(self.get_outputs(shape_group)[i].shape, dtype=self.get_outputs(shape_group)[i].dtype) + npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) + ret = axclrt_lib.axclrtMemcpy(npy_ptr, buffer_addr, npy_size, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST) + if 0 != ret: + raise RuntimeError(f"axclrtMemcpy failed for output {i}.") name = self.get_outputs(shape_group)[i].name if name in output_names: outputs.append(npy)