Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions axengine/_axclrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def run(
self,
output_names: list[str],
input_feed: dict[str, np.ndarray],
run_options=None
run_options=None,
shape_group: int = 0
):
self._validate_input(input_feed)
self._validate_output(output_names)
Expand All @@ -340,13 +341,16 @@ def run(
raise RuntimeError("axclrtSetCurrentContext failed")

if None is output_names:
output_names = [o.name for o in self.get_outputs()]
output_names = [o.name for o in self.get_outputs(shape_group)]

if (shape_group > self._shape_count - 1) or (shape_group < 0):
raise ValueError(f"Invalid shape group: {shape_group}")

# fill model io
dev_prt = axclrt_cffi.new("void **")
dev_size = axclrt_cffi.new("uint64_t *")
for key, npy in input_feed.items():
for i, one in enumerate(self.get_inputs()):
for i, one in enumerate(self.get_inputs(shape_group)):
if one.name == key:
assert (
list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
Expand All @@ -363,21 +367,23 @@ def run(
raise RuntimeError(f"axclrtMemcpy failed for input {i}.")

# execute model
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io[0])
ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0])

# get output
outputs = []
if 0 == ret:
for i in range(len(self.get_outputs())):
for i in range(len(self.get_outputs(shape_group))):
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
if 0 != ret:
raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
if 0 != ret:
raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
name = self.get_outputs()[i].name
npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape)
npy = np.frombuffer(
axclrt_cffi.buffer(
self._io[0].pOutputs[i].pVirAddr, npy_size
),
dtype=self.get_outputs(shape_group)[i].dtype,
).reshape(self.get_outputs(shape_group)[i].shape)
name = self.get_outputs(shape_group)[i].name
if name in output_names:
outputs.append(npy)
return outputs
Expand Down
32 changes: 21 additions & 11 deletions axengine/_axe.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,21 @@ def run(
self,
output_names: list[str],
input_feed: dict[str, np.ndarray],
run_options=None
run_options=None,
shape_group: int = 0
):
self._validate_input(input_feed)
self._validate_output(output_names)

if None is output_names:
output_names = [o.name for o in self.get_outputs()]
output_names = [o.name for o in self.get_outputs(shape_group)]

if (shape_group > self._shape_count - 1) or (shape_group < 0):
raise ValueError(f"Invalid shape group: {shape_group}")

# fill model io
for key, npy in input_feed.items():
for i, one in enumerate(self.get_inputs()):
for i, one in enumerate(self.get_inputs(shape_group)):
if one.name == key:
assert (
list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
Expand All @@ -377,26 +381,32 @@ def run(
break

# execute model
ret = engine_lib.AX_ENGINE_RunSyncV2(
self._handle[0], self._context[0], self._io
)
if self._shape_count > 1:
ret = engine_lib.AX_ENGINE_RunGroupIOSync(
self._handle[0], self._context[0], shape_group, self._io
)
else:
ret = engine_lib.AX_ENGINE_RunSyncV2(
self._handle[0], self._context[0], self._io
)

# flush output
outputs = []
if 0 == ret:
for i in range(len(self.get_outputs())):
for i in range(len(self.get_outputs(shape_group))):
sys_lib.AX_SYS_MinvalidateCache(
self._io[0].pOutputs[i].phyAddr,
self._io[0].pOutputs[i].pVirAddr,
self._io[0].pOutputs[i].nSize,
)
npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape)
npy = np.frombuffer(
engine_cffi.buffer(
self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
self._io[0].pOutputs[i].pVirAddr, npy_size
),
dtype=self.get_outputs()[i].dtype,
).reshape(self.get_outputs()[i].shape)
name = self.get_outputs()[i].name
dtype=self.get_outputs(shape_group)[i].dtype,
).reshape(self.get_outputs(shape_group)[i].shape)
name = self.get_outputs(shape_group)[i].name
if name in output_names:
outputs.append(npy)
return outputs
Expand Down
5 changes: 3 additions & 2 deletions axengine/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def run(
self,
output_names: list[str] | None,
input_feed: dict[str, np.ndarray],
run_options=None
run_options=None,
shape_group: int = 0
) -> list[np.ndarray]:
return self._sess.run(output_names, input_feed, run_options)
return self._sess.run(output_names, input_feed, run_options, shape_group)