From a488a260edc440e9e8698ea7f8eb56f08eaedf3f Mon Sep 17 00:00:00 2001 From: kaclohol <314377460@qq.com> Date: Mon, 3 Mar 2025 15:01:58 +0800 Subject: [PATCH] add shape group support --- axengine/_axclrt.py | 28 +++++++++++++++++----------- axengine/_axe.py | 32 +++++++++++++++++++++----------- axengine/_session.py | 5 +++-- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/axengine/_axclrt.py b/axengine/_axclrt.py index 559e716..73a75d5 100644 --- a/axengine/_axclrt.py +++ b/axengine/_axclrt.py @@ -330,7 +330,8 @@ def run( self, output_names: list[str], input_feed: dict[str, np.ndarray], - run_options=None + run_options=None, + shape_group: int = 0 ): self._validate_input(input_feed) self._validate_output(output_names) @@ -340,13 +341,16 @@ def run( raise RuntimeError("axclrtSetCurrentContext failed") if None is output_names: - output_names = [o.name for o in self.get_outputs()] + output_names = [o.name for o in self.get_outputs(shape_group)] + + if (shape_group > self._shape_count - 1) or (shape_group < 0): + raise ValueError(f"Invalid shape group: {shape_group}") # fill model io dev_prt = axclrt_cffi.new("void **") dev_size = axclrt_cffi.new("uint64_t *") for key, npy in input_feed.items(): - for i, one in enumerate(self.get_inputs()): + for i, one in enumerate(self.get_inputs(shape_group)): if one.name == key: assert ( list(one.shape) == list(npy.shape) and one.dtype == npy.dtype @@ -363,21 +367,23 @@ def run( raise RuntimeError(f"axclrtMemcpy failed for input {i}.") # execute model - ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0]) + ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0]) # get output outputs = [] if 0 == ret: - for i in range(len(self.get_outputs())): + for i in range(len(self.get_outputs(shape_group))): ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) if 0 != ret: raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") - npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype) - npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) - ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST) - if 0 != ret: - raise RuntimeError(f"axclrtMemcpy failed for output {i}.") - name = self.get_outputs()[i].name + npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) + npy = np.frombuffer( + axclrt_cffi.buffer( + self._io[0].pOutputs[i].pVirAddr, npy_size + ), + dtype=self.get_outputs(shape_group)[i].dtype, + ).reshape(self.get_outputs(shape_group)[i].shape) + name = self.get_outputs(shape_group)[i].name if name in output_names: outputs.append(npy) return outputs diff --git a/axengine/_axe.py b/axengine/_axe.py index 3adb634..09f7e72 100644 --- a/axengine/_axe.py +++ b/axengine/_axe.py @@ -346,17 +346,21 @@ def run( self, output_names: list[str], input_feed: dict[str, np.ndarray], - run_options=None + run_options=None, + shape_group: int = 0 ): self._validate_input(input_feed) self._validate_output(output_names) if None is output_names: - output_names = [o.name for o in self.get_outputs()] + output_names = [o.name for o in self.get_outputs(shape_group)] + + if (shape_group > self._shape_count - 1) or (shape_group < 0): + raise ValueError(f"Invalid shape group: {shape_group}") # fill model io for key, npy in input_feed.items(): - for i, one in enumerate(self.get_inputs()): + for i, one in enumerate(self.get_inputs(shape_group)): if one.name == key: assert ( list(one.shape) == list(npy.shape) and one.dtype == npy.dtype @@ -377,26 +381,32 @@ def run( break # execute model - ret = engine_lib.AX_ENGINE_RunSyncV2( - self._handle[0], self._context[0], self._io - ) + if self._shape_count > 1: + ret = engine_lib.AX_ENGINE_RunGroupIOSync( + self._handle[0], self._context[0], shape_group, self._io + ) + else: + ret = engine_lib.AX_ENGINE_RunSyncV2( + self._handle[0], self._context[0], self._io + ) # flush output outputs = [] if 0 == ret: - for i in range(len(self.get_outputs())): + for i in range(len(self.get_outputs(shape_group))): sys_lib.AX_SYS_MinvalidateCache( self._io[0].pOutputs[i].phyAddr, self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize, ) + npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) npy = np.frombuffer( engine_cffi.buffer( - self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize + self._io[0].pOutputs[i].pVirAddr, npy_size ), - dtype=self.get_outputs()[i].dtype, - ).reshape(self.get_outputs()[i].shape) - name = self.get_outputs()[i].name + dtype=self.get_outputs(shape_group)[i].dtype, + ).reshape(self.get_outputs(shape_group)[i].shape) + name = self.get_outputs(shape_group)[i].name if name in output_names: outputs.append(npy) return outputs diff --git a/axengine/_session.py b/axengine/_session.py index 1f321b4..ab452ba 100644 --- a/axengine/_session.py +++ b/axengine/_session.py @@ -112,6 +112,7 @@ def run( self, output_names: list[str] | None, input_feed: dict[str, np.ndarray], - run_options=None + run_options=None, + shape_group: int = 0 ) -> list[np.ndarray]: - return self._sess.run(output_names, input_feed, run_options) + return self._sess.run(output_names, input_feed, run_options, shape_group)