diff --git a/axengine/_axclrt.py b/axengine/_axclrt.py index e07fa59..109d329 100644 --- a/axengine/_axclrt.py +++ b/axengine/_axclrt.py @@ -372,8 +372,10 @@ def run( # get output outputs = [] + origin_output_names = [_o.name for _o in self.get_outputs(shape_group)] + outputs_ranks = [output_names.index(_on) for _on in origin_output_names] if 0 == ret: - for i in range(len(self.get_outputs(shape_group))): + for i in outputs_ranks: ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) if 0 != ret: raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") diff --git a/axengine/_axe.py b/axengine/_axe.py index 6bfb431..f6deffb 100644 --- a/axengine/_axe.py +++ b/axengine/_axe.py @@ -395,8 +395,10 @@ def run( # flush output outputs = [] + origin_output_names = [_o.name for _o in self.get_outputs(shape_group)] + outputs_ranks = [output_names.index(_on) for _on in origin_output_names] if 0 == ret: - for i in range(len(self.get_outputs(shape_group))): + for i in outputs_ranks: sys_lib.AX_SYS_MinvalidateCache( self._io[0].pOutputs[i].phyAddr, self._io[0].pOutputs[i].pVirAddr,