diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1c5063a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + pip install types-cffi + + - name: Lint with ruff + run: ruff check . + + - name: Type check with mypy + run: mypy axengine + + - name: Test with pytest + run: pytest -m "not hardware" --cov=axengine --cov-report=term-missing diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4cf011c --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.sisyphus/ +__pycache__/ +.coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a126cac --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.9 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + args: [--python-version=3.8] diff --git a/axengine/__init__.py b/axengine/__init__.py index bacaa74..895c264 100644 --- a/axengine/__init__.py +++ b/axengine/__init__.py @@ -8,15 +8,38 @@ # thanks to community contributors list below: # zylo117: https://github.com/zylo117, first implementation of the axclrt backend -from ._providers import axengine_provider_name, axclrt_provider_name -from ._providers import get_all_providers, get_available_providers +import os + +from ._logging import get_logger +from ._node import NodeArg as NodeArg +from ._providers import ( + axclrt_provider_name as axclrt_provider_name, +) +from ._providers import ( + axengine_provider_name as 
axengine_provider_name, +) +from ._providers import ( + get_all_providers as get_all_providers, +) +from ._providers import ( + get_available_providers, +) +from ._session import InferenceSession as InferenceSession +from ._session import SessionOptions as SessionOptions + +logger = get_logger(__name__) -# check if axclrt is installed, or is a supported chip(e.g. AX650, AX620E etc.) _available_providers = get_available_providers() -if not _available_providers: - raise ImportError( - f"No providers found. Please make sure you have installed one of the following: {get_all_providers()}") -print("[INFO] Available providers: ", _available_providers) +_is_test_or_ci = bool(os.getenv("CI") or os.getenv("PYTEST_CURRENT_TEST")) -from ._node import NodeArg -from ._session import SessionOptions, InferenceSession +if not _available_providers: + _provider_error_message = ( + "No execution providers available. Install the required hardware libraries " + "(ax_engine or axcl_rt) and check that the target hardware/driver is available." 
+ ) + if _is_test_or_ci: + logger.warning(_provider_error_message) + else: + raise ImportError(_provider_error_message) +else: + logger.info("Available providers: %s", _available_providers) diff --git a/axengine/_axclrt.py b/axengine/_axclrt.py index 109d329..d79f2b8 100644 --- a/axengine/_axclrt.py +++ b/axengine/_axclrt.py @@ -8,43 +8,25 @@ import atexit import os -import time -from typing import Any, Sequence +from typing import Any -import ml_dtypes as mldt import numpy as np from ._axclrt_capi import axclrt_cffi, axclrt_lib -from ._axclrt_types import VNPUType, ModelType +from ._axclrt_types import VNPUType from ._base_session import Session, SessionOptions +from ._logging import get_logger from ._node import NodeArg +from ._utils_axclrt import _transform_dtype_axclrt as _transform_dtype -__all__: ["AXCLRTSession"] +logger = get_logger(__name__) + +__all__ = ["AXCLRTSession"] _is_axclrt_initialized = False _is_axclrt_engine_initialized = False -_all_model_instances = [] - - -def _transform_dtype(dtype): - if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8): - return np.dtype(np.uint8) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8): - return np.dtype(np.int8) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16): - return np.dtype(np.uint16) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16): - return np.dtype(np.int16) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32): - return np.dtype(np.uint32) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32): - return np.dtype(np.int32) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32): - return np.dtype(np.float32) - elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16): - return np.dtype(mldt.bfloat16) - else: - raise 
ValueError(f"Unsupported data type '{dtype}'.") +_all_model_instances: list[Any] = [] + def _initialize_axclrt(): global _is_axclrt_initialized @@ -72,62 +54,66 @@ def _finalize_axclrt(): def _get_vnpu_type() -> VNPUType: vnpu_type = axclrt_cffi.new("axclrtEngineVNpuKind *") - ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu_type) + ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu_type) # type: ignore[attr-defined] if ret != 0: raise RuntimeError("Failed to get VNPU attribute.") return VNPUType(vnpu_type[0]) def _get_version(): - major, minor, patch = axclrt_cffi.new('int32_t *'), axclrt_cffi.new('int32_t *'), axclrt_cffi.new( - 'int32_t *') + major, minor, patch = axclrt_cffi.new("int32_t *"), axclrt_cffi.new("int32_t *"), axclrt_cffi.new("int32_t *") axclrt_lib.axclrtGetVersion(major, minor, patch) - return f'{major[0]}.{minor[0]}.{patch[0]}' + return f"{major[0]}.{minor[0]}.{patch[0]}" class AXCLRTSession(Session): + """AXCL runtime-backed session for loading and executing AX models. + + Attributes: + soc_name: The SOC name reported by the AXCL runtime. + _device_index: The selected device index used for this session. 
+ """ + def __init__( - self, - path_or_bytes: str | bytes | os.PathLike, - sess_options: SessionOptions | None = None, - provider_options: dict[Any, Any] | None = None, - **kwargs, + self, + path_or_bytes: str | bytes | os.PathLike, + sess_options: SessionOptions | None = None, + provider_options: dict[Any, Any] | None = None, + **kwargs, ) -> None: super().__init__() self._device_index = 0 - self._io = None - self._model_id = None + self._io: Any | None = None + self._model_id: Any | None = None - if provider_options is not None and "device_id" in provider_options[0]: - self._device_index = provider_options[0].get("device_id", 0) + if provider_options is not None and isinstance(provider_options, dict) and "device_id" in provider_options: + self._device_index = provider_options.get("device_id", 0) lst = axclrt_cffi.new("axclrtDeviceList *") - ret = axclrt_lib.axclrtGetDeviceList(lst) - if ret != 0 or lst.num == 0: - raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.") + ret = axclrt_lib.axclrtGetDeviceList(lst) # type: ignore[attr-defined] + if ret != 0 or lst.num == 0: # type: ignore[attr-defined] + raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.") # type: ignore[attr-defined] - if self._device_index >= lst.num: - raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.") + if self._device_index >= lst.num: # type: ignore[attr-defined] + raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.") # type: ignore[attr-defined] - self._device_id = lst.devices[self._device_index] - ret = axclrt_lib.axclrtSetDevice(self._device_id) - if ret != 0 or lst.num == 0: + self._device_id = lst.devices[self._device_index] # type: ignore[attr-defined] + ret = axclrt_lib.axclrtSetDevice(self._device_id) # type: ignore[attr-defined] + if ret != 0 or lst.num == 0: # type: ignore[attr-defined] raise RuntimeError(f"Set AXCL device 
failed 0x{ret:08x}.") global _is_axclrt_engine_initialized - vnpu_type = axclrt_cffi.cast( - "axclrtEngineVNpuKind", VNPUType.DISABLED.value - ) + vnpu_type = axclrt_cffi.cast("axclrtEngineVNpuKind", VNPUType.DISABLED.value) # try to initialize NPU as disabled - ret = axclrt_lib.axclrtEngineInit(vnpu_type) + ret = axclrt_lib.axclrtEngineInit(vnpu_type) # type: ignore[attr-defined] # if failed, try to get vnpu type if 0 != ret: vnpu = axclrt_cffi.new("axclrtEngineVNpuKind *") - ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu) + ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu) # type: ignore[attr-defined] # if failed, that means the NPU is not available if ret != 0: - raise RuntimeError(f"axclrtEngineInit as {vnpu.value} failed 0x{ret:08x}.") + raise RuntimeError(f"axclrtEngineInit as {vnpu.value} failed 0x{ret:08x}.") # type: ignore[attr-defined] # if success, that means the NPU is already initialized as vnpu.value # so the initialization is failed. # this means the other users maybe uninitialized the NPU suddenly @@ -136,16 +122,16 @@ def __init__( # it because the api looks like onnxruntime, so there no window avoid this. # such as the life. 
else: - print(f"[WARNING] Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu.value}.") + logger.warning(f"Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu.value}.") # type: ignore[attr-defined] # initialize NPU successfully, mark the flag to ensure the engine will be finalized else: _is_axclrt_engine_initialized = True - self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode() - print(f"[INFO] SOC Name: {self.soc_name}") + self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode() # type: ignore[union-attr,attr-defined] + logger.info(f"SOC Name: {self.soc_name}") self._thread_context = axclrt_cffi.new("axclrtContext *") - ret = axclrt_lib.axclrtGetCurrentContext(self._thread_context) + ret = axclrt_lib.axclrtGetCurrentContext(self._thread_context) # type: ignore[attr-defined] if ret != 0: raise RuntimeError("axclrtGetCurrentContext failed") @@ -155,13 +141,13 @@ def __init__( # get vnpu type self._vnpu_type = _get_vnpu_type() - print(f"[INFO] VNPU type: {self._vnpu_type}") + logger.info(f"VNPU type: {self._vnpu_type}") # load model ret = self._load(path_or_bytes) if 0 != ret: raise RuntimeError("Failed to load model.") - print(f"[INFO] Compiler version: {self._get_model_tool_version()}") + logger.info(f"Compiler version: {self._get_model_tool_version()}") # get model info self._info = self._get_info() @@ -178,10 +164,17 @@ def __del__(self): self._unload() _all_model_instances.remove(self) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._unload() + return False + def _load(self, path_or_bytes): # model buffer, almost copied from onnx runtime if isinstance(path_or_bytes, (str, os.PathLike)): - _model_path = axclrt_cffi.new("char[]", path_or_bytes.encode('utf-8')) + _model_path = axclrt_cffi.new("char[]", path_or_bytes.encode("utf-8")) ret = axclrt_lib.axclrtEngineLoadFromFile(_model_path, self._model_id) if ret != 0: raise 
RuntimeError("axclrtEngineLoadFromFile failed.") @@ -189,12 +182,14 @@ def _load(self, path_or_bytes): _model_buffer = axclrt_cffi.new("char[]", path_or_bytes) _model_buffer_size = len(path_or_bytes) - dev_mem_ptr = axclrt_cffi.new('void **', axclrt_cffi.NULL) + dev_mem_ptr = axclrt_cffi.new("void **", axclrt_cffi.NULL) ret = axclrt_lib.axclrtMalloc(dev_mem_ptr, _model_buffer_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY) if ret != 0: raise RuntimeError("axclrtMalloc failed.") - ret = axclrt_lib.axclrtMemcpy(dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE) + ret = axclrt_lib.axclrtMemcpy( + dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE + ) if ret != 0: axclrt_lib.axclrtFree(dev_mem_ptr[0]) raise RuntimeError("axclrtMemcpy failed.") @@ -276,7 +271,7 @@ def _get_outputs(self): for group in range(self._shape_count): one_group_io = [] for index in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])): - cffi_name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index) + cffi_name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index) name = axclrt_cffi.string(cffi_name).decode("utf-8") cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *") @@ -328,16 +323,19 @@ def _prepare_io(self): return _io def run( - self, - output_names: list[str], - input_feed: dict[str, np.ndarray], - run_options=None, - shape_group: int = 0 - ): + self, + output_names: list[str] | None, + input_feed: dict[str, np.ndarray], + run_options: object | None = None, + shape_group: int = 0, + ) -> list[np.ndarray]: self._validate_input(input_feed) self._validate_output(output_names) - ret = axclrt_lib.axclrtSetCurrentContext(self._thread_context[0]) + if self._io is None: + raise RuntimeError("IO not initialized") + + ret = axclrt_lib.axclrtSetCurrentContext(self._thread_context[0]) # type: ignore[attr-defined] if ret != 0: raise RuntimeError("axclrtSetCurrentContext failed") @@ -353,22 
+351,29 @@ def run( for key, npy in input_feed.items(): for i, one in enumerate(self.get_inputs(shape_group)): if one.name == key: - assert ( - list(one.shape) == list(npy.shape) and one.dtype == npy.dtype - ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, howerver gets input with shape {npy.shape} and dtype {npy.dtype}" + assert list(one.shape) == list(npy.shape) and one.dtype == npy.dtype, ( + f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, howerver gets input with shape {npy.shape} and dtype {npy.dtype}" + ) if not (npy.flags.c_contiguous or npy.flags.f_contiguous): npy = np.ascontiguousarray(npy) npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) - ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) + ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io[0], i, dev_prt, dev_size) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.") - ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE) + ret = axclrt_lib.axclrtMemcpy( # type: ignore[attr-defined] + dev_prt[0], + npy_ptr, + npy.nbytes, + axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE, # type: ignore[attr-defined] + ) if 0 != ret: raise RuntimeError(f"axclrtMemcpy failed for input {i}.") - # execute model - ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0]) + if self._model_id is None or self._context_id is None: + raise RuntimeError("Model or context not initialized") + + ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], shape_group, self._io[0]) # type: ignore[attr-defined] # get output outputs = [] @@ -376,14 +381,16 @@ def run( outputs_ranks = [output_names.index(_on) for _on in origin_output_names] if 0 == ret: for i in outputs_ranks: - ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) + ret = 
axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.") buffer_addr = dev_prt[0] - npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) + npy_size = np.dtype(self.get_outputs(shape_group)[i].dtype).itemsize * np.prod( + self.get_outputs(shape_group)[i].shape + ) npy = np.zeros(self.get_outputs(shape_group)[i].shape, dtype=self.get_outputs(shape_group)[i].dtype) npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data) - ret = axclrt_lib.axclrtMemcpy(npy_ptr, buffer_addr, npy_size, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST) + ret = axclrt_lib.axclrtMemcpy(npy_ptr, buffer_addr, npy_size, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError(f"axclrtMemcpy failed for output {i}.") name = self.get_outputs(shape_group)[i].name diff --git a/axengine/_axclrt_capi.py b/axengine/_axclrt_capi.py index 1719a94..7d644af 100644 --- a/axengine/_axclrt_capi.py +++ b/axengine/_axclrt_capi.py @@ -9,7 +9,7 @@ from cffi import FFI -__all__: ["axclrt_cffi", "axclrt_lib"] +__all__: list[str] = ["axclrt_cffi", "axclrt_lib"] axclrt_cffi = FFI() @@ -29,13 +29,13 @@ uint32_t num; int32_t devices[AXCL_MAX_DEVICE_COUNT]; } axclrtDeviceList; - + typedef enum axclrtMemMallocPolicy { AXCL_MEM_MALLOC_HUGE_FIRST, AXCL_MEM_MALLOC_HUGE_ONLY, AXCL_MEM_MALLOC_NORMAL_ONLY } axclrtMemMallocPolicy; - + typedef enum axclrtMemcpyKind { AXCL_MEMCPY_HOST_TO_HOST, AXCL_MEMCPY_HOST_TO_DEVICE, //!< host vir -> device phy @@ -60,7 +60,7 @@ AXCL_VNPU_BIG_LITTLE = 2, AXCL_VNPU_LITTLE_BIG = 3, } axclrtEngineVNpuKind; - + typedef enum axclrtEngineDataType { AXCL_DATA_TYPE_NONE = 0, AXCL_DATA_TYPE_INT4 = 1, @@ -80,13 +80,13 @@ AXCL_DATA_TYPE_FP32 = 15, AXCL_DATA_TYPE_FP64 = 16, } axclrtEngineDataType; - + typedef enum axclrtEngineDataLayout { AXCL_DATA_LAYOUT_NONE = 0, 
AXCL_DATA_LAYOUT_NHWC = 0, AXCL_DATA_LAYOUT_NCHW = 1, } axclrtEngineDataLayout; - + typedef struct axclrtEngineIODims { int32_t dimCount; int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT]; @@ -190,9 +190,9 @@ rt_name = "axcl_rt" rt_path = ctypes.util.find_library(rt_name) -assert ( - rt_path is not None -), f"Failed to find library {rt_name}. Please ensure it is installed and in the library path." +if rt_path is None: + raise ImportError(f"Failed to find library {rt_name}. Please ensure it is installed and in the library path.") axclrt_lib = axclrt_cffi.dlopen(rt_path) -assert axclrt_lib is not None, f"Failed to load library {rt_path}. Please ensure it is installed and in the library path." +if axclrt_lib is None: + raise ImportError(f"Failed to load library {rt_path}. Please ensure it is installed and in the library path.") diff --git a/axengine/_axe.py b/axengine/_axe.py index f6deffb..cbe5933 100644 --- a/axengine/_axe.py +++ b/axengine/_axe.py @@ -7,43 +7,25 @@ import atexit import os -from typing import Any, Sequence +from typing import Any -import ml_dtypes as mldt import numpy as np -from ._axe_capi import sys_lib, engine_cffi, engine_lib -from ._axe_types import VNPUType, ModelType, ChipType +from ._axe_capi import engine_cffi, engine_lib, sys_lib +from ._axe_types import ChipType, ModelType, VNPUType from ._base_session import Session, SessionOptions +from ._logging import get_logger from ._node import NodeArg +from ._utils import _transform_dtype -__all__: ["AXEngineSession"] +logger = get_logger(__name__) + +__all__ = ["AXEngineSession"] _is_sys_initialized = False _is_engine_initialized = False -def _transform_dtype(dtype): - if dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT8): - return np.dtype(np.uint8) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT8): - return np.dtype(np.int8) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT16): - return 
np.dtype(np.uint16) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT16): - return np.dtype(np.int16) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT32): - return np.dtype(np.uint32) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT32): - return np.dtype(np.int32) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_FLOAT32): - return np.dtype(np.float32) - elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_BFLOAT16): - return np.dtype(mldt.bfloat16) - else: - raise ValueError(f"Unsupported data type '{dtype}'.") - - def _check_cffi_func_exists(lib, func_name): try: getattr(lib, func_name) @@ -68,10 +50,10 @@ def _get_version(): def _get_vnpu_type() -> VNPUType: vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *") - ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type) + ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError("Failed to get VNPU attribute.") - return VNPUType(vnpu_type.eHardMode) + return VNPUType(vnpu_type.eHardMode) # type: ignore[attr-defined] def _initialize_engine(): @@ -87,17 +69,15 @@ def _initialize_engine(): ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type) if 0 != ret: # this means the NPU was not initialized - vnpu_type.eHardMode = engine_cffi.cast( - "AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value - ) + vnpu_type.eHardMode = engine_cffi.cast("AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value) ret = engine_lib.AX_ENGINE_Init(vnpu_type) if ret != 0: raise RuntimeError("Failed to initialize ax sys engine.") _is_engine_initialized = True - print(f"[INFO] Chip type: {_get_chip_type()}") - print(f"[INFO] VNPU type: {_get_vnpu_type()}") - print(f"[INFO] Engine version: {_get_version()}") + logger.info(f"Chip type: {_get_chip_type()}") + logger.info(f"VNPU type: {_get_vnpu_type()}") + logger.info(f"Engine version: 
{_get_version()}") def _finalize_engine(): @@ -114,12 +94,19 @@ def _finalize_engine(): class AXEngineSession(Session): + """ONNXRuntime-compatible inference session backed by AxEngine runtime. + + This session loads an AX model from a file path or in-memory bytes, + prepares input/output metadata, and executes synchronous inference through + the Axera engine C API. + """ + def __init__( - self, - path_or_bytes: str | bytes | os.PathLike, - sess_options: SessionOptions | None = None, - provider_options: dict[Any, Any] | None = None, - **kwargs, + self, + path_or_bytes: str | bytes | os.PathLike[str], + sess_options: SessionOptions | None = None, + provider_options: dict[Any, Any] | None = None, + **kwargs, ) -> None: super().__init__() @@ -151,18 +138,18 @@ def __init__( self._model_type = self._get_model_type() if self._chip_type is ChipType.MC20E: if self._model_type is ModelType.FULL: - print(f"[INFO] Model type: {self._model_type.value} (full core)") + logger.info(f"Model type: {self._model_type.value} (full core)") if self._model_type is ModelType.HALF: - print(f"[INFO] Model type: {self._model_type.value} (half core)") + logger.info(f"Model type: {self._model_type.value} (half core)") if self._chip_type is ChipType.MC50: if self._model_type is ModelType.SINGLE: - print(f"[INFO] Model type: {self._model_type.value} (single core)") + logger.info(f"Model type: {self._model_type.value} (single core)") if self._model_type is ModelType.DUAL: - print(f"[INFO] Model type: {self._model_type.value} (dual core)") + logger.info(f"Model type: {self._model_type.value} (dual core)") if self._model_type is ModelType.TRIPLE: - print(f"[INFO] Model type: {self._model_type.value} (triple core)") + logger.info(f"Model type: {self._model_type.value} (triple core)") if self._chip_type is ChipType.M57H: - print(f"[INFO] Model type: {self._model_type.value} (single core)") + logger.info(f"Model type: {self._model_type.value} (single core)") # check model type if self._chip_type is 
ChipType.MC50: @@ -174,10 +161,7 @@ def __init__( raise ValueError( f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}." ) - if ( - self._vnpu_type is VNPUType.BIG_LITTLE - or self._vnpu_type is VNPUType.LITTLE_BIG - ): + if self._vnpu_type is VNPUType.BIG_LITTLE or self._vnpu_type is VNPUType.LITTLE_BIG: if self._model_type is ModelType.TRIPLE: raise ValueError( f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}." @@ -197,13 +181,13 @@ def __init__( ret = self._load() if 0 != ret: raise RuntimeError("Failed to load model.") - print(f"[INFO] Compiler version: {self._get_model_tool_version()}") + logger.info(f"Compiler version: {self._get_model_tool_version()}") # get shape group count try: self._shape_count = self._get_shape_count() except AttributeError as e: - print(f"[WARNING] {e}") + logger.warning(f"{e}") self._shape_count = 1 # get model shape @@ -216,12 +200,8 @@ def __init__( self._cmm_token = engine_cffi.new("AX_S8[]", b"PyEngine") self._io[0].nInputSize = len(self.get_inputs()) self._io[0].nOutputSize = len(self.get_outputs()) - _inputs= engine_cffi.new( - "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize) - ) - _outputs = engine_cffi.new( - "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize) - ) + _inputs = engine_cffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)) + _outputs = engine_cffi.new("AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)) self._io_buffers = (_inputs, _outputs) self._io[0].pInputs = _inputs self._io[0].pOutputs = _outputs @@ -235,9 +215,7 @@ def __init__( phy = engine_cffi.new("AX_U64*") vir = engine_cffi.new("AX_VOID**") self._io_inputs_pool.append((phy, vir)) - ret = sys_lib.AX_SYS_MemAllocCached( - phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token - ) + ret = sys_lib.AX_SYS_MemAllocCached(phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token) # type: ignore[attr-defined] if 0 != 
ret: raise RuntimeError("Failed to allocate memory for input.") self._io[0].pInputs[i].phyAddr = phy[0] @@ -252,30 +230,31 @@ def __init__( phy = engine_cffi.new("AX_U64*") vir = engine_cffi.new("AX_VOID**") self._io_outputs_pool.append((phy, vir)) - ret = sys_lib.AX_SYS_MemAllocCached( - phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token - ) + ret = sys_lib.AX_SYS_MemAllocCached(phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError("Failed to allocate memory for output.") self._io[0].pOutputs[i].phyAddr = phy[0] self._io[0].pOutputs[i].pVirAddr = vir[0] + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._unload() + return False + def __del__(self): self._unload() def _get_model_type(self) -> ModelType: model_type = engine_cffi.new("AX_ENGINE_MODEL_TYPE_T *") - ret = engine_lib.AX_ENGINE_GetModelType( - self._model_buffer, self._model_buffer_size, model_type - ) + ret = engine_lib.AX_ENGINE_GetModelType(self._model_buffer, self._model_buffer_size, model_type) # type: ignore[attr-defined] if 0 != ret: raise RuntimeError("Failed to get model type.") return ModelType(model_type[0]) def _get_model_tool_version(self): - model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion( - self._handle[0] - ) + model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion(self._handle[0]) return engine_cffi.string(model_tool_version).decode("utf-8") def _load(self): @@ -285,13 +264,9 @@ def _load(self): # for onnx runtime do not support one model multiple context running in multi-thread as far as I know, so # the engine handle and context will create only once - ret = engine_lib.AX_ENGINE_CreateHandleV2( - self._handle, self._model_buffer, self._model_buffer_size, extra - ) + ret = engine_lib.AX_ENGINE_CreateHandleV2(self._handle, self._model_buffer, self._model_buffer_size, extra) if 0 == ret: - ret = 
engine_lib.AX_ENGINE_CreateContextV2( - self._handle[0], self._context - ) + ret = engine_lib.AX_ENGINE_CreateContextV2(self._handle[0], self._context) return ret def _get_info(self): @@ -305,9 +280,7 @@ def _get_info(self): else: for i in range(self._shape_count): info = engine_cffi.new("AX_ENGINE_IO_INFO_T **") - ret = engine_lib.AX_ENGINE_GetGroupIOInfo( - self._handle[0], i, info - ) + ret = engine_lib.AX_ENGINE_GetGroupIOInfo(self._handle[0], i, info) if 0 != ret: raise RuntimeError(f"Failed to get model the {i}th shape.") total_info.append(info) @@ -329,10 +302,10 @@ def _get_io(self, io_type: str): io_info = [] for group in range(self._shape_count): one_group_io = [] - for index in range(getattr(self._info[group][0], f'n{io_type}Size')): - current_io = getattr(self._info[group][0], f'p{io_type}s')[index] - name = engine_cffi.string(current_io.pName).decode("utf-8") - shape = [current_io.pShape[i] for i in range(current_io.nShapeSize)] + for index in range(getattr(self._info[group][0], f"n{io_type}Size")): + current_io = getattr(self._info[group][0], f"p{io_type}s")[index] + name = engine_cffi.string(current_io.pName).decode("utf-8") # type: ignore[union-attr] + shape = tuple(current_io.pShape[i] for i in range(current_io.nShapeSize)) dtype = _transform_dtype(current_io.eDataType) meta = NodeArg(name, dtype, shape) one_group_io.append(meta) @@ -340,23 +313,24 @@ def _get_io(self, io_type: str): return io_info def _get_inputs(self): - return self._get_io('Input') + return self._get_io("Input") def _get_outputs(self): - return self._get_io('Output') + return self._get_io("Output") def run( - self, - output_names: list[str], - input_feed: dict[str, np.ndarray], - run_options=None, - shape_group: int = 0 - ): + self, + output_names: list[str] | None, + input_feed: dict[str, np.ndarray], + run_options: object | None = None, + shape_group: int = 0, + ) -> list[np.ndarray]: self._validate_input(input_feed) self._validate_output(output_names) if None is output_names: 
output_names = [o.name for o in self.get_outputs(shape_group)] + assert output_names is not None if (shape_group > self._shape_count - 1) or (shape_group < 0): raise ValueError(f"Invalid shape group: {shape_group}") @@ -365,18 +339,16 @@ def run( for key, npy in input_feed.items(): for i, one in enumerate(self.get_inputs(shape_group)): if one.name == key: - assert ( - list(one.shape) == list(npy.shape) and one.dtype == npy.dtype - ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}" + assert list(one.shape) == list(npy.shape) and one.dtype == npy.dtype, ( + f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}" + ) if not (npy.flags.c_contiguous or npy.flags.f_contiguous): npy = np.ascontiguousarray(npy) npy_ptr = engine_cffi.cast("void *", npy.ctypes.data) - engine_cffi.memmove( - self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes - ) - sys_lib.AX_SYS_MflushCache( + engine_cffi.memmove(self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes) + sys_lib.AX_SYS_MflushCache( # type: ignore[attr-defined] self._io[0].pInputs[i].phyAddr, self._io[0].pInputs[i].pVirAddr, self._io[0].pInputs[i].nSize, @@ -385,13 +357,9 @@ def run( # execute model if self._shape_count > 1: - ret = engine_lib.AX_ENGINE_RunGroupIOSync( - self._handle[0], self._context[0], shape_group, self._io - ) + ret = engine_lib.AX_ENGINE_RunGroupIOSync(self._handle[0], self._context[0], shape_group, self._io) # type: ignore[attr-defined] else: - ret = engine_lib.AX_ENGINE_RunSyncV2( - self._handle[0], self._context[0], self._io - ) + ret = engine_lib.AX_ENGINE_RunSyncV2(self._handle[0], self._context[0], self._io) # type: ignore[attr-defined] # flush output outputs = [] @@ -399,18 +367,22 @@ def run( outputs_ranks = [output_names.index(_on) for _on in origin_output_names] if 0 == ret: for i in outputs_ranks: - sys_lib.AX_SYS_MinvalidateCache( 
+ sys_lib.AX_SYS_MinvalidateCache( # type: ignore[attr-defined] self._io[0].pOutputs[i].phyAddr, self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize, ) - npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape) - npy = np.frombuffer( - engine_cffi.buffer( - self._io[0].pOutputs[i].pVirAddr, npy_size - ), - dtype=self.get_outputs(shape_group)[i].dtype, - ).reshape(self.get_outputs(shape_group)[i].shape).copy() + npy_size = np.dtype(self.get_outputs(shape_group)[i].dtype).itemsize * np.prod( + self.get_outputs(shape_group)[i].shape + ) + npy = ( + np.frombuffer( + engine_cffi.buffer(self._io[0].pOutputs[i].pVirAddr, npy_size), + dtype=self.get_outputs(shape_group)[i].dtype, # type: ignore[call-overload] + ) + .reshape(self.get_outputs(shape_group)[i].shape) + .copy() + ) name = self.get_outputs(shape_group)[i].name if name in output_names: outputs.append(npy) diff --git a/axengine/_axe_capi.py b/axengine/_axe_capi.py index 2d9ecec..248f613 100644 --- a/axengine/_axe_capi.py +++ b/axengine/_axe_capi.py @@ -10,7 +10,7 @@ from cffi import FFI -__all__: ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"] +__all__: list[str] = ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"] sys_cffi = FFI() @@ -39,12 +39,12 @@ sys_name = "ax_sys" sys_path = ctypes.util.find_library(sys_name) -assert ( - sys_path is not None -), f"Failed to find library {sys_name}. Please ensure it is installed and in the library path." +if sys_path is None: + raise ImportError(f"Failed to find library {sys_name}. Please ensure it is installed and in the library path.") sys_lib = sys_cffi.dlopen(sys_path) -assert sys_lib is not None, f"Failed to load library {sys_path}. Please ensure it is installed and in the library path." +if sys_lib is None: + raise ImportError(f"Failed to load library {sys_path}. 
Please ensure it is installed and in the library path.") engine_cffi = FFI() @@ -58,7 +58,7 @@ typedef signed char AX_S8; typedef char AX_CHAR; typedef void AX_VOID; - + typedef enum { AX_FALSE = 0, AX_TRUE = 1, @@ -148,13 +148,13 @@ AX_ENGINE_COLOR_SPACE_T eColorSpace; AX_U64 u64Reserved[18]; } AX_ENGINE_IO_META_EX_T; - + typedef struct { AX_ENGINE_NPU_SET_T nNpuSet; AX_S8* pName; AX_U32 reserve[8]; } AX_ENGINE_HANDLE_EXTRA_T; - + typedef struct _AX_ENGINE_CMM_INFO_T { AX_U32 nCMMSize; @@ -315,9 +315,11 @@ engine_name = "ax_engine" engine_path = ctypes.util.find_library(engine_name) -assert ( - engine_path is not None -), f"Failed to find library {engine_name}. Please ensure it is installed and in the library path." +assert engine_path is not None, ( + f"Failed to find library {engine_name}. Please ensure it is installed and in the library path." +) engine_lib = engine_cffi.dlopen(engine_path) -assert engine_lib is not None, f"Failed to load library {engine_path}. Please ensure it is installed and in the library path." +assert engine_lib is not None, ( + f"Failed to load library {engine_path}. Please ensure it is installed and in the library path." +) diff --git a/axengine/_base_session.py b/axengine/_base_session.py index 86d25c0..ff6e32f 100644 --- a/axengine/_base_session.py +++ b/axengine/_base_session.py @@ -13,37 +13,72 @@ class SessionOptions: - pass + """Configuration options for session initialization. + + Stores session-level configuration parameters used when creating + and initializing a session instance. + """ + + pass # Placeholder for future session configuration options class Session(ABC): + """Base class for inference sessions. + + Provides common interface for running model inference on Axera NPU devices. + Supports multiple shape groups for dynamic input/output configurations. 
+ """ + def __init__(self) -> None: self._shape_count = 0 - self._inputs = [] - self._outputs = [] + self._inputs: list[list[NodeArg]] = [] + self._outputs: list[list[NodeArg]] = [] - def _validate_input(self, feed_input_names: dict[str, np.ndarray]): + def _validate_input(self, feed_input_names: dict[str, np.ndarray]) -> None: missing_input_names = [] for i in self.get_inputs(): if i.name not in feed_input_names: missing_input_names.append(i.name) if missing_input_names: raise ValueError( - f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).") + f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})." + ) - def _validate_output(self, output_names: list[str]): + def _validate_output(self, output_names: list[str] | None) -> None: if output_names is not None: for name in output_names: if name not in [o.name for o in self.get_outputs()]: raise ValueError(f"Output name '{name}' is not in model outputs name list.") def get_inputs(self, shape_group: int = 0) -> list[NodeArg]: + """Get input node information for the specified shape group. + + Args: + shape_group: Index of the shape group (default: 0). + + Returns: + List of input NodeArg objects for the shape group. + + Raises: + ValueError: If shape_group is out of range. + """ if shape_group > self._shape_count: raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") selected_info = self._inputs[shape_group] return selected_info def get_outputs(self, shape_group: int = 0) -> list[NodeArg]: + """Get output node information for the specified shape group. + + Args: + shape_group: Index of the shape group (default: 0). + + Returns: + List of output NodeArg objects for the shape group. + + Raises: + ValueError: If shape_group is out of range. 
+ """ if shape_group > self._shape_count: raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") selected_info = self._outputs[shape_group] @@ -51,9 +86,16 @@ def get_outputs(self, shape_group: int = 0) -> list[NodeArg]: @abstractmethod def run( - self, - output_names: list[str] | None, - input_feed: dict[str, np.ndarray], - run_options=None + self, output_names: list[str] | None, input_feed: dict[str, np.ndarray], run_options: object | None = None ) -> list[np.ndarray]: + """Run inference on the model. + + Args: + output_names: Names of outputs to retrieve, or None for all outputs. + input_feed: Dictionary mapping input names to numpy arrays. + run_options: Optional runtime configuration. + + Returns: + List of output numpy arrays. + """ pass diff --git a/axengine/_logging.py b/axengine/_logging.py new file mode 100644 index 0000000..f46cc74 --- /dev/null +++ b/axengine/_logging.py @@ -0,0 +1,27 @@ +"""Unified logging infrastructure for axengine.""" + +import logging +import os + + +def get_logger(name: str) -> logging.Logger: + """Get a logger instance with unified configuration. + + Args: + name: Logger name (typically __name__) + + Returns: + Configured logger instance + """ + logger = logging.getLogger(name) + + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + + log_level = os.environ.get("AXENGINE_LOG_LEVEL", "INFO").upper() + logger.setLevel(getattr(logging, log_level, logging.INFO)) + + return logger diff --git a/axengine/_node.py b/axengine/_node.py index cf0459e..451b3eb 100644 --- a/axengine/_node.py +++ b/axengine/_node.py @@ -7,7 +7,15 @@ class NodeArg(object): - def __init__(self, name, dtype, shape): - self.name = name - self.dtype = dtype - self.shape = shape + """Represents a node argument with type and shape information. 
+ + Attributes: + name: The name of the argument. + dtype: The data type of the argument (e.g., 'float32', 'int64'). + shape: The shape of the argument as a tuple of integers. + """ + + def __init__(self, name: str, dtype: str, shape: tuple[int, ...]) -> None: + self.name: str = name + self.dtype: str = dtype + self.shape: tuple[int, ...] = shape diff --git a/axengine/_providers.py b/axengine/_providers.py index dfab02e..c4ffb7e 100644 --- a/axengine/_providers.py +++ b/axengine/_providers.py @@ -8,11 +8,11 @@ import ctypes.util as cutil providers = [] -axengine_provider_name = 'AxEngineExecutionProvider' -axclrt_provider_name = 'AXCLRTExecutionProvider' +axengine_provider_name = "AxEngineExecutionProvider" +axclrt_provider_name = "AXCLRTExecutionProvider" -_axengine_lib_name = 'ax_engine' -_axclrt_lib_name = 'axcl_rt' +_axengine_lib_name = "ax_engine" +_axclrt_lib_name = "axcl_rt" # check if axcl_rt is installed, so if available, it's the default provider if cutil.find_library(_axclrt_lib_name) is not None: @@ -23,9 +23,9 @@ providers.append(axengine_provider_name) -def get_all_providers(): +def get_all_providers() -> list[str]: return [axengine_provider_name, axclrt_provider_name] -def get_available_providers(): +def get_available_providers() -> list[str]: return providers diff --git a/axengine/_session.py b/axengine/_session.py index ab452ba..a092998 100644 --- a/axengine/_session.py +++ b/axengine/_session.py @@ -6,29 +6,60 @@ # import os -from typing import Any, Sequence +from typing import Any, cast import numpy as np from ._base_session import SessionOptions +from ._logging import get_logger from ._node import NodeArg -from ._providers import axclrt_provider_name, axengine_provider_name -from ._providers import get_available_providers +from ._providers import axclrt_provider_name, axengine_provider_name, get_available_providers + +logger = get_logger(__name__) class InferenceSession: + """Inference session for running ONNX models on NPU. 
+ + This class provides a high-level interface for loading and running ONNX models + on Axera NPU devices. It supports multiple execution providers and is compatible + with ONNXRuntime API. + + Attributes: + _sess: Internal session object for the selected provider. + _provider: Name of the execution provider being used. + _provider_options: Configuration options for the provider. + """ + def __init__( - self, - path_or_bytes: str | bytes | os.PathLike, - sess_options: SessionOptions | None = None, - providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, - provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs, + self, + path_or_bytes: str | bytes | os.PathLike, + sess_options: SessionOptions | None = None, + providers: str | list[str | tuple[str, dict[Any, Any]]] | None = None, + provider_options: list[dict[Any, Any]] | None = None, + **kwargs, ) -> None: - self._sess = None + """Initialize an InferenceSession. + + Args: + path_or_bytes: Path to the ONNX model file or model bytes. + sess_options: Session configuration options. Defaults to None. + providers: Execution provider(s) to use. Can be a string for single provider + or list of strings/tuples for multiple providers. Defaults to None (uses first available). + provider_options: Provider-specific configuration options. Defaults to None. + **kwargs: Additional arguments passed to the provider session. + + Raises: + ValueError: If selected provider is not available or no valid provider found. + TypeError: If provider format is invalid. + RuntimeError: If session creation fails. 
+ """ + self._sess: Any self._sess_options = sess_options - self._provider = None - self._provider_options = None + self._provider: str | None = None + self._provider_options: dict[Any, Any] | None = None self._available_providers = get_available_providers() + sess: Any | None = None # the providers should be available at least one, checked in __init__.py if providers is None: @@ -45,74 +76,116 @@ def __init__( elif isinstance(providers, list): _unavailable_provider = [] for p in providers: - assert isinstance(p, str) or isinstance(p, tuple), \ - f"Invalid provider type: {type(p)}. Must be str or tuple." + if not (isinstance(p, str) or isinstance(p, tuple)): + raise TypeError(f"Invalid provider type: {type(p)}. Must be str or tuple.") if isinstance(p, str): if p not in self._available_providers: _unavailable_provider.append(p) elif self._provider is None: self._provider = p - if isinstance(p, tuple): - assert len(p) == 2, f"Invalid provider type: {p}. Must be tuple with 2 elements." - assert isinstance(p[0], str), f"Invalid provider type: {type(p[0])}. Must be str." - assert isinstance(p[1], dict), f"Invalid provider type: {type(p[1])}. Must be dict." + elif isinstance(p, tuple): + if len(p) != 2: + raise ValueError(f"Invalid provider type: {p}. Must be tuple with 2 elements.") + if not isinstance(p[0], str): + raise TypeError(f"Invalid provider type: {type(p[0])}. Must be str.") + if not isinstance(p[1], dict): + raise TypeError(f"Invalid provider type: {type(p[1])}. Must be dict.") if p[0] not in self._available_providers: _unavailable_provider.append(p[0]) elif self._provider is None: self._provider = p[0] - # FIXME: check provider options + # Provider options dict is validated above (line 91-92). + # Provider-specific validation happens in session constructors. 
self._provider_options = p[1] if _unavailable_provider: if self._provider is None: raise ValueError(f"Selected provider(s): {_unavailable_provider} is(are) not available.") else: - print(f"[WARNING] Selected provider(s): {_unavailable_provider} is(are) not available.") + logger.warning(f"Selected provider(s): {_unavailable_provider} is(are) not available.") + + logger.info(f"Using provider: {self._provider}") - # FIXME: can we remove this check? - if self._provider is None: - raise ValueError(f"No available provider found in {providers}.") - print(f"[INFO] Using provider: {self._provider}") + provider_opts = None + if self._provider_options is not None: + provider_opts = self._provider_options + elif provider_options is not None and len(provider_options) > 0: + provider_opts = provider_options[0] if self._provider == axclrt_provider_name: from ._axclrt import AXCLRTSession - self._sess = AXCLRTSession(path_or_bytes, sess_options, provider_options, **kwargs) + + sess = AXCLRTSession(path_or_bytes, sess_options, provider_opts, **kwargs) if self._provider == axengine_provider_name: from ._axe import AXEngineSession - self._sess = AXEngineSession(path_or_bytes, sess_options, provider_options, **kwargs) - if self._sess is None: + + sess = AXEngineSession(path_or_bytes, sess_options, provider_opts, **kwargs) + if sess is None: raise RuntimeError(f"Create session failed with provider: {self._provider}") + self._sess = sess - # add to support 'with' statement def __enter__(self): + """Enter context manager.""" + self._sess.__enter__() return self def __exit__(self, exc_type, exc_value, traceback): - # not suppress exceptions - return False + """Exit context manager.""" + return self._sess.__exit__(exc_type, exc_value, traceback) - def get_session_options(self): - """ - Return the session options. See :class:`axengine.SessionOptions`. + def get_session_options(self) -> SessionOptions | None: + """Get session options. 
+ + Returns: + SessionOptions: The session configuration options. """ return self._sess_options - def get_providers(self): - """ - Return list of registered execution providers. + def get_providers(self) -> str | None: + """Get the execution provider name. + + Returns: + str: Name of the registered execution provider. """ return self._provider def get_inputs(self, shape_group: int = 0) -> list[NodeArg]: - return self._sess.get_inputs(shape_group) + """Get model input metadata. + + Args: + shape_group: Shape group index for dynamic-shape models. + + Returns: + list[NodeArg]: Input node metadata. + """ + return cast(list[NodeArg], self._sess.get_inputs(shape_group)) def get_outputs(self, shape_group: int = 0) -> list[NodeArg]: - return self._sess.get_outputs(shape_group) + """Get model output metadata. + + Args: + shape_group: Shape group index for dynamic-shape models. + + Returns: + list[NodeArg]: Output node metadata. + """ + return cast(list[NodeArg], self._sess.get_outputs(shape_group)) def run( - self, - output_names: list[str] | None, - input_feed: dict[str, np.ndarray], - run_options=None, - shape_group: int = 0 + self, + output_names: list[str] | None, + input_feed: dict[str, np.ndarray], + run_options: object | None = None, + shape_group: int = 0, ) -> list[np.ndarray]: - return self._sess.run(output_names, input_feed, run_options, shape_group) + """Run inference with given model inputs. + + Args: + output_names: Optional output names to fetch, or None for all outputs. + input_feed: Input tensor mapping keyed by model input name. + run_options: Optional runtime options for provider-specific execution. + shape_group: Shape group index for dynamic-shape models. + + Returns: + list[np.ndarray]: Inference outputs in model-defined order. 
+ """ + return cast(list[np.ndarray], self._sess.run(output_names, input_feed, run_options, shape_group)) diff --git a/axengine/_utils.py b/axengine/_utils.py new file mode 100644 index 0000000..2530d56 --- /dev/null +++ b/axengine/_utils.py @@ -0,0 +1,25 @@ +import ml_dtypes as mldt +import numpy as np + +from ._axe_capi import engine_cffi, engine_lib + + +def _transform_dtype(dtype): + if dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT8): + return np.dtype(np.uint8) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT8): + return np.dtype(np.int8) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT16): + return np.dtype(np.uint16) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT16): + return np.dtype(np.int16) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT32): + return np.dtype(np.uint32) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT32): + return np.dtype(np.int32) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_FLOAT32): + return np.dtype(np.float32) + elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_BFLOAT16): + return np.dtype(mldt.bfloat16) + else: + raise ValueError(f"Unsupported data type '{dtype}'.") diff --git a/axengine/_utils_axclrt.py b/axengine/_utils_axclrt.py new file mode 100644 index 0000000..37f0656 --- /dev/null +++ b/axengine/_utils_axclrt.py @@ -0,0 +1,30 @@ +# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved. +# +# AXCL dtype helpers — kept separate from _utils so AxEngine (board) path does not +# import axcl_rt at module load time. 
+ +import ml_dtypes as mldt +import numpy as np + +from ._axclrt_capi import axclrt_cffi, axclrt_lib + + +def _transform_dtype_axclrt(dtype): + if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8): + return np.dtype(np.uint8) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8): + return np.dtype(np.int8) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16): + return np.dtype(np.uint16) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16): + return np.dtype(np.int16) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32): + return np.dtype(np.uint32) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32): + return np.dtype(np.int32) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32): + return np.dtype(np.float32) + elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16): + return np.dtype(mldt.bfloat16) + else: + raise ValueError(f"Unsupported data type '{dtype}'.") diff --git a/examples/yolov5.py b/examples/yolov5.py index 26c8562..5a1fdb2 100644 --- a/examples/yolov5.py +++ b/examples/yolov5.py @@ -197,13 +197,13 @@ def letterbox_yolov5( - im, - new_shape=(640, 640), - color=(114, 114, 114), - auto=True, - scaleFill=False, - scaleup=True, - stride=32, + im, + new_shape=(640, 640), + color=(114, 114, 114), + auto=True, + scaleFill=False, + scaleup=True, + stride=32, ): # Resize and pad image while meeting stride-multiple constraints shape = im.shape[:2] # current shape [height, width] @@ -233,9 +233,7 @@ def letterbox_yolov5( im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - im = cv2.copyMakeBorder( - im, top, bottom, left, right, cv2.BORDER_CONSTANT, 
value=color - ) # add border + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border return im, ratio, (dw, dh) @@ -251,16 +249,14 @@ def draw_bbox(image, bboxes, classes=None, show_label=True, threshold=0.1): """ bboxes: [x_min, y_min, x_max, y_max, probability, cls_id] format coordinates. """ - if classes == None: + if classes is None: classes = {v: k for k, v in COCO_CATEGORIES.items()} num_classes = len(classes) image_h, image_w, _ = image.shape hsv_tuples = [(1.0 * x / num_classes, 1.0, 1.0) for x in range(num_classes)] colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) - colors = list( - map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors) - ) + colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors)) random.seed(0) random.shuffle(colors) @@ -277,14 +273,10 @@ def draw_bbox(image, bboxes, classes=None, show_label=True, threshold=0.1): bbox_thick = int(0.6 * (image_h + image_w) / 600) c1, c2 = (coor[0], coor[1]), (coor[2], coor[3]) cv2.rectangle(image, c1, c2, bbox_color, bbox_thick) - print( - f" {class_ind:>3}: {CLASS_NAMES[class_ind]:<10}: {coor}, score: {score*100:3.2f}%" - ) + print(f" {class_ind:>3}: {CLASS_NAMES[class_ind]:<10}: {coor}, score: {score * 100:3.2f}%") if show_label: bbox_mess = "%s: %.2f" % (CLASS_NAMES[class_ind], score) - t_size = cv2.getTextSize( - bbox_mess, 0, fontScale, thickness=bbox_thick // 2 - )[0] + t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0] cv2.rectangle(image, c1, (c1[0] + t_size[0], c1[1] - t_size[1] - 3), bbox_color, -1) cv2.putText( @@ -347,9 +339,7 @@ def nms(proposals, iou_threshold, conf_threshold, multi_label=False): if nonzero_indices.size < 0: return i, j = nonzero_indices.T - bboxes = np.hstack( - (bboxes[i], proposals[i, j + 5][:, None], j[:, None].astype(float)) - ) + bboxes = np.hstack((bboxes[i], proposals[i, j + 5][:, None], j[:, None].astype(float))) else: 
confidences = proposals[:, 5:] conf = confidences.max(axis=1, keepdims=True) @@ -373,9 +363,7 @@ def nms(proposals, iou_threshold, conf_threshold, multi_label=False): max_ind = np.argmax(cls_bboxes[:, 4]) best_bbox = cls_bboxes[max_ind] best_bboxes.append(best_bbox) - cls_bboxes = np.concatenate( - [cls_bboxes[:max_ind], cls_bboxes[max_ind + 1:]] - ) + cls_bboxes = np.concatenate([cls_bboxes[:max_ind], cls_bboxes[max_ind + 1 :]]) iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4]) weight = np.ones((len(iou),), dtype=np.float32) @@ -427,12 +415,8 @@ def clip_coords(boxes, shape): def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): if ratio_pad is None: - gain = min( - img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1] - ) # gain = old / new - pad = (img1_shape[1] - img0_shape[1] * gain) / 2, ( - img1_shape[0] - img0_shape[0] * gain - ) / 2 # wh padding + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding else: gain = ratio_pad[0][0] pad = ratio_pad[1] @@ -453,8 +437,8 @@ def post_processing(outputs, origin_shape, input_shape): return pred -def detect_yolov5(model_path, image_path, save_path, repeat_times, selected_provider='AUTO', selected_device_id=0): - if selected_provider == 'AUTO': +def detect_yolov5(model_path, image_path, save_path, repeat_times, selected_provider="AUTO", selected_device_id=0): + if selected_provider == "AUTO": # Use AUTO to let the pyengine choose the first available provider session = axe.InferenceSession(model_path) else: @@ -508,26 +492,20 @@ def error(self, message): if __name__ == "__main__": ap = ExampleParser(description="YOLOv5 example") - ap.add_argument('-m', '--model-path', type=str, help='model path', required=True) - ap.add_argument('-i', '--image-path', type=str, help='image path', required=True) - ap.add_argument( - '-s', "--save-path", type=str, 
default="YOLOv5_OUT.jpg", help="detected output image save path" - ) - ap.add_argument('-r', '--repeat', type=int, help='repeat times', default=10) + ap.add_argument("-m", "--model-path", type=str, help="model path", required=True) + ap.add_argument("-i", "--image-path", type=str, help="image path", required=True) + ap.add_argument("-s", "--save-path", type=str, default="YOLOv5_OUT.jpg", help="detected output image save path") + ap.add_argument("-r", "--repeat", type=int, help="repeat times", default=10) ap.add_argument( - '-p', - '--provider', + "-p", + "--provider", type=str, choices=["AUTO", f"{axclrt_provider_name}", f"{axengine_provider_name}"], help=f'"AUTO", "{axclrt_provider_name}", "{axengine_provider_name}"', - default='AUTO' + default="AUTO", ) ap.add_argument( - '-d', - '--device-id', - type=int, - help=R'axclrt device index, depends on how many cards inserted', - default=0 + "-d", "--device-id", type=int, help=R"axclrt device index, depends on how many cards inserted", default=0 ) args = ap.parse_args() diff --git a/mypy_check.txt b/mypy_check.txt new file mode 100644 index 0000000..ee88a80 --- /dev/null +++ b/mypy_check.txt @@ -0,0 +1,19 @@ +pyproject.toml: [mypy]: python_version: Python 3.8 is not supported (must be 3.9 or higher) +axengine/_axe_capi.py:11: error: Library stubs not installed for "cffi" [import-untyped] +axengine/_axclrt_capi.py:10: error: Library stubs not installed for "cffi" [import-untyped] +axengine/_axclrt_capi.py:10: note: Hint: "python3 -m pip install types-cffi" +axengine/_axclrt_capi.py:10: note: (or run "mypy --install-types" to install all missing stub packages) +axengine/_axclrt_capi.py:10: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports +axengine/_axe.py:303: error: Argument 3 to "NodeArg" has incompatible type "list[Any]"; expected "tuple[int, ...]" [arg-type] +axengine/_axe.py:367: error: "str" has no attribute "itemsize" [attr-defined] +axengine/_axclrt.py:351: error: Value of type 
"Optional[Any]" is not indexable [index] +axengine/_axclrt.py:361: error: Value of type "Optional[Any]" is not indexable [index] +axengine/_axclrt.py:369: error: Value of type "Optional[Any]" is not indexable [index] +axengine/_axclrt.py:373: error: "str" has no attribute "itemsize" [attr-defined] +axengine/_session.py:110: error: Argument 3 to "AXCLRTSession" has incompatible type "Optional[Sequence[dict[Any, Any]]]"; expected "Optional[dict[Any, Any]]" [arg-type] +axengine/_session.py:114: error: Incompatible types in assignment (expression has type "AXEngineSession", variable has type "Optional[AXCLRTSession]") [assignment] +axengine/_session.py:114: error: Argument 3 to "AXEngineSession" has incompatible type "Optional[Sequence[dict[Any, Any]]]"; expected "Optional[dict[Any, Any]]" [arg-type] +axengine/_session.py:155: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "get_inputs" [union-attr] +axengine/_session.py:166: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "get_outputs" [union-attr] +axengine/_session.py:186: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "run" [union-attr] +Found 14 errors in 5 files (checked 13 source files) diff --git a/mypy_errors.txt b/mypy_errors.txt new file mode 100644 index 0000000..51fdf32 --- /dev/null +++ b/mypy_errors.txt @@ -0,0 +1,53 @@ +pyproject.toml: [mypy]: python_version: Python 3.8 is not supported (must be 3.9 or higher) +axengine/_axe_capi.py:11: error: Library stubs not installed for "cffi" [import-untyped] +axengine/_axe_capi.py:13: error: Bracketed expression "[...]" is not valid as a type [valid-type] +axengine/_axclrt_capi.py:10: error: Library stubs not installed for "cffi" [import-untyped] +axengine/_axclrt_capi.py:10: note: Hint: "python3 -m pip install types-cffi" +axengine/_axclrt_capi.py:10: note: (or run "mypy --install-types" to install all missing stub packages) +axengine/_axclrt_capi.py:10: note: See 
https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports +axengine/_axclrt_capi.py:12: error: Bracketed expression "[...]" is not valid as a type [valid-type] +axengine/_base_session.py:29: error: Need type annotation for "_inputs" (hint: "_inputs: list[] = ...") [var-annotated] +axengine/_base_session.py:30: error: Need type annotation for "_outputs" (hint: "_outputs: list[] = ...") [var-annotated] +axengine/_base_session.py:52: error: Returning Any from function declared to return "list[NodeArg]" [no-any-return] +axengine/_base_session.py:58: error: Returning Any from function declared to return "list[NodeArg]" [no-any-return] +axengine/_axe.py:23: error: Bracketed expression "[...]" is not valid as a type [valid-type] +axengine/_axe.py:23: note: Did you mean "List[...]"? +axengine/_axe.py:99: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axe.py:100: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axe.py:101: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axe.py:303: error: Argument 3 to "NodeArg" has incompatible type "list[Any]"; expected "tuple[int, ...]" [arg-type] +axengine/_axe.py:314: error: Signature of "run" incompatible with supertype "axengine._base_session.Session" [override] +axengine/_axe.py:314: note: Superclass: +axengine/_axe.py:314: note: def run(self, output_names: Optional[list[str]], input_feed: dict[str, ndarray[Any, dtype[Any]]], run_options: Any = ...) -> list[ndarray[Any, dtype[Any]]] +axengine/_axe.py:314: note: Subclass: +axengine/_axe.py:314: note: def run(self, output_names: list[str], input_feed: dict[str, ndarray[Any, dtype[Any]]], run_options: Any = ..., shape_group: int = ...) -> Any +axengine/_axe.py:361: error: "str" has no attribute "itemsize" [attr-defined] +axengine/_axclrt.py:24: error: Bracketed expression "[...]" is not valid as a type [valid-type] +axengine/_axclrt.py:24: note: Did you mean "List[...]"? 
+axengine/_axclrt.py:28: error: Need type annotation for "_all_model_instances" (hint: "_all_model_instances: list[] = ...") [var-annotated] +axengine/_axclrt.py:72: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axclrt.py:73: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axclrt.py:74: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_axclrt.py:318: error: Signature of "run" incompatible with supertype "axengine._base_session.Session" [override] +axengine/_axclrt.py:318: note: Superclass: +axengine/_axclrt.py:318: note: def run(self, output_names: Optional[list[str]], input_feed: dict[str, ndarray[Any, dtype[Any]]], run_options: Any = ...) -> list[ndarray[Any, dtype[Any]]] +axengine/_axclrt.py:318: note: Subclass: +axengine/_axclrt.py:318: note: def run(self, output_names: list[str], input_feed: dict[str, ndarray[Any, dtype[Any]]], run_options: Any = ..., shape_group: int = ...) -> Any +axengine/_axclrt.py:345: error: Value of type "Optional[Any]" is not indexable [index] +axengine/_axclrt.py:355: error: Value of type "Optional[Any]" is not indexable [index] +axengine/_axclrt.py:363: error: Value of type "Optional[Any]" is not indexable [index] +axengine/_axclrt.py:367: error: "str" has no attribute "itemsize" [attr-defined] +axengine/_session.py:36: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_session.py:37: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_session.py:38: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_session.py:39: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_session.py:110: error: Argument 3 to "AXCLRTSession" has incompatible type "Optional[Sequence[dict[Any, Any]]]"; expected "Optional[dict[Any, Any]]" [arg-type] +axengine/_session.py:114: error: Incompatible types in assignment (expression has type "AXEngineSession", variable has type "Optional[AXCLRTSession]") 
[assignment] +axengine/_session.py:114: error: Argument 3 to "AXEngineSession" has incompatible type "Optional[Sequence[dict[Any, Any]]]"; expected "Optional[dict[Any, Any]]" [arg-type] +axengine/_session.py:155: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "get_inputs" [union-attr] +axengine/_session.py:166: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "get_outputs" [union-attr] +axengine/_session.py:169: error: X | Y syntax for unions requires Python 3.10 [syntax] +axengine/_session.py:182: error: Returning Any from function declared to return "list[ndarray[Any, dtype[Any]]]" [no-any-return] +axengine/_session.py:182: error: Item "None" of "Optional[AXCLRTSession]" has no attribute "run" [union-attr] +axengine/_session.py:182: error: Argument 1 to "run" of "AXCLRTSession" has incompatible type "Optional[list[str]]"; expected "list[str]" [arg-type] +Found 38 errors in 6 files (checked 13 source files) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4c6046d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,70 @@ +[build-system] +requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "axengine" +version = "0.1.3" +description = "Python API for Axera NPU Runtime" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "BSD-3-Clause"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "cffi>=1.0.0", + "ml-dtypes>=0.1.0", + "numpy>=1.22", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-cov>=4.0", + "ruff>=0.1.0", + "mypy>=1.0", + "griffe>=0.30.0", +] + +[tool.ruff] +line-length = 120 +target-version = "py38" 
+ +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +markers = [ + "hardware: tests that require AX hardware", + "unit: unit tests that don't require hardware", +] + +[tool.coverage.run] +source = ["axengine"] +omit = ["*/tests/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", +] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..1df8605 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +pytest>=7.0 +pytest-cov>=4.0 +ruff>=0.1.0 +mypy>=1.0 +griffe>=0.30.0 diff --git a/script/README.md b/script/README.md new file mode 100644 index 0000000..7da829c --- /dev/null +++ b/script/README.md @@ -0,0 +1,268 @@ +# 硬件测试手册 + +本文档提供 PyAXEngine 硬件测试的完整说明,包含 4 个核心测试用例。 + +## 硬件要求 + +### 支持的硬件平台 + +- **开发板**: AX650N / AX630C(如爱芯派Pro) +- **算力卡**: AX650 M.2 算力卡 + +### 软件依赖 + +- Python >= 3.8 +- axengine 已安装 +- 测试模型文件(.axmodel 格式) + +### 驱动要求 + +- 开发板: 需要安装 libax_engine.so +- 算力卡: 需要安装 AXCL 驱动,参考 [AXCL 文档](https://axcl-docs.readthedocs.io/zh-cn/latest/) + +--- + +## 测试 1: AxEngine 基础功能测试 + +### 测试目的 + +验证 AxEngineExecutionProvider 在开发板上的基本功能,包括会话创建和模型信息获取。 + +### 硬件要求 + +- AX650N 或 AX630C 开发板 +- 已安装 libax_engine.so + +### 环境配置 + +```bash +# 确认 axengine 已安装 +pip show axengine + +# 准备测试模型 +# 模型路径示例: /opt/data/npu/models/mobilenetv2.axmodel +``` + +### 运行命令 + +```bash +cd script +python3 test_axengine_basic.py <模型路径> +``` + +示例: +```bash +python3 test_axengine_basic.py /opt/data/npu/models/mobilenetv2.axmodel +``` + +### 预期输出 + +``` +[INFO] Testing AxEngine with model: /opt/data/npu/models/mobilenetv2.axmodel +[INFO] Available providers: ['AXCLRTExecutionProvider', 
'AxEngineExecutionProvider'] +[INFO] Successfully created session with AxEngineExecutionProvider +[INFO] Model inputs: 1, outputs: 1 +``` + +退出码: 0(成功) + +### 故障排查 + +**问题**: `[ERROR] AxEngineExecutionProvider not available` +- 检查是否在开发板上运行 +- 确认 libax_engine.so 已正确安装 +- 运行 `ldd` 检查库依赖 + +**问题**: `[ERROR] Failed to create session` +- 确认模型文件路径正确 +- 检查模型文件是否为有效的 .axmodel 格式 +- 确认模型与芯片型号匹配(AX650N/AX630C) + +--- + +## 测试 2: AXCLRT 基础功能测试 + +### 测试目的 + +验证 AXCLRTExecutionProvider 在算力卡上的基本功能。 + +### 硬件要求 + +- AX650 M.2 算力卡 +- 已安装 AXCL 驱动 + +### 环境配置 + +```bash +# 确认算力卡已识别 +lspci | grep AXERA + +# 确认 AXCL 驱动已加载 +lsmod | grep axcl + +# 确认 axengine 已安装 +pip show axengine +``` + +### 运行命令 + +```bash +cd script +python3 test_axclrt_basic.py <模型路径> +``` + +示例: +```bash +python3 test_axclrt_basic.py /opt/data/npu/models/mobilenetv2.axmodel +``` + +### 预期输出 + +``` +[INFO] Testing AXCLRT with model: /opt/data/npu/models/mobilenetv2.axmodel +[INFO] Available providers: ['AXCLRTExecutionProvider', 'AxEngineExecutionProvider'] +[INFO] Successfully created session with AXCLRTExecutionProvider +[INFO] Model inputs: 1, outputs: 1 +``` + +退出码: 0(成功) + +### 故障排查 + +**问题**: `[ERROR] AXCLRTExecutionProvider not available` +- 检查算力卡是否正确插入 +- 确认 AXCL 驱动已安装: `lsmod | grep axcl` +- 重新安装驱动或重启系统 + +**问题**: `[ERROR] Failed to create session` +- 检查模型文件路径 +- 确认算力卡有足够内存 +- 查看系统日志: `dmesg | tail` + +--- + +## 测试 3: AxEngine 会话创建集成测试 + +### 测试目的 + +验证 AxEngineExecutionProvider 的会话创建和 provider 查询功能。 + +### 硬件要求 + +- AX650N 或 AX630C 开发板 + +### 环境配置 + +```bash +# 安装 pytest +pip install pytest + +# 准备测试模型 +# 在项目根目录放置 model.axmodel +``` + +### 运行命令 + +```bash +# 在项目根目录运行 +pytest tests/test_integration.py::TestHardwareIntegration::test_axengine_session_creation -v +``` + +### 预期输出 + +``` +tests/test_integration.py::TestHardwareIntegration::test_axengine_session_creation PASSED [100%] +``` + +### 故障排查 + +**问题**: `FileNotFoundError: model.axmodel` +- 在项目根目录创建或链接 model.axmodel +- 使用软链接: `ln -s 
/opt/data/npu/models/mobilenetv2.axmodel model.axmodel` + +**问题**: 断言失败 `assert sess.get_providers() == 'AxEngineExecutionProvider'` +- 检查 provider 是否正确初始化 +- 确认开发板环境正常 + +--- + +## 测试 4: 推理运行集成测试 + +### 测试目的 + +验证完整的推理流程,包括输入准备、推理执行和输出获取。 + +### 硬件要求 + +- AX650N 或 AX630C 开发板,或 AX650 M.2 算力卡 + +### 环境配置 + +```bash +# 安装依赖 +pip install pytest numpy + +# 准备测试模型 +# 在项目根目录放置 model.axmodel +``` + +### 运行命令 + +```bash +# 在项目根目录运行 +pytest tests/test_integration.py::TestHardwareIntegration::test_inference_run -v +``` + +### 预期输出 + +``` +tests/test_integration.py::TestHardwareIntegration::test_inference_run PASSED [100%] +``` + +### 故障排查 + +**问题**: `FileNotFoundError: model.axmodel` +- 参考测试 3 的解决方案 + +**问题**: 推理失败或输出为空 +- 检查输入数据形状是否与模型匹配 +- 确认模型输入类型为 float32 +- 检查 NPU 内存是否充足 + +**问题**: 数值异常 +- 确认 SDK 版本支持模型的数据类型(bf16 需要 AX650 SDK >= 2.18) +- 检查模型编译版本与运行时版本是否匹配 + +--- + +## 批量运行所有硬件测试 + +```bash +# 运行所有硬件测试 +pytest tests/test_integration.py -m hardware -v + +# 跳过硬件测试(仅运行单元测试) +pytest -m "not hardware" +``` + +## 常见问题 + +### SDK 版本问题 + +AX650 SDK 2.18 和 AX620E SDK 3.12 之前的版本不支持 bf16,可能导致 LLM 模型返回 unknown dtype。 + +解决方案: +- 升级 SDK 到最新版本 +- 或仅更新 libax_engine.so + +### 算力卡 vs 开发板选择 + +- **开发板模式**: 使用 AxEngineExecutionProvider,适合快速原型验证 +- **算力卡模式**: 使用 AXCLRTExecutionProvider,适合生产部署 + +如果主要使用算力卡,建议使用 [pyAXCL](https://github.com/AXERA-TECH/pyaxcl) 获得完整 API 支持。 + +## 技术支持 + +- GitHub Issues: https://github.com/AXERA-TECH/pyaxengine/issues +- QQ 群: 139953715 diff --git a/script/test_axclrt_basic.py b/script/test_axclrt_basic.py new file mode 100644 index 0000000..cdb275f --- /dev/null +++ b/script/test_axclrt_basic.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +import sys + +import axengine as axe + + +def main(model_path): + print(f"[INFO] Testing AXCLRT with model: {model_path}") + + available = axe.get_available_providers() + print(f"[INFO] Available providers: {available}") + + if axe.axclrt_provider_name not in available: + print(f"[ERROR] {axe.axclrt_provider_name} not 
available") + return False + + try: + session = axe.InferenceSession(model_path, providers=[axe.axclrt_provider_name]) + print(f"[INFO] Successfully created session with {axe.axclrt_provider_name}") + + inputs = session.get_inputs() + outputs = session.get_outputs() + print(f"[INFO] Model inputs: {len(inputs)}, outputs: {len(outputs)}") + + return True + except Exception as e: + print(f"[ERROR] Failed to create session: {e}") + return False + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python test_axclrt_basic.py <model_path>") + sys.exit(1) + + success = main(sys.argv[1]) + sys.exit(0 if success else 1) diff --git a/script/test_axengine_basic.py b/script/test_axengine_basic.py new file mode 100644 index 0000000..90ba393 --- /dev/null +++ b/script/test_axengine_basic.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +import sys + +import axengine as axe + + +def main(model_path): + print(f"[INFO] Testing AxEngine with model: {model_path}") + + available = axe.get_available_providers() + print(f"[INFO] Available providers: {available}") + + if axe.axengine_provider_name not in available: + print(f"[ERROR] {axe.axengine_provider_name} not available") + return False + + try: + session = axe.InferenceSession(model_path, providers=[axe.axengine_provider_name]) + print(f"[INFO] Successfully created session with {axe.axengine_provider_name}") + + inputs = session.get_inputs() + outputs = session.get_outputs() + print(f"[INFO] Model inputs: {len(inputs)}, outputs: {len(outputs)}") + + return True + except Exception as e: + print(f"[ERROR] Failed to create session: {e}") + return False + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python test_axengine_basic.py <model_path>") + sys.exit(1) + + success = main(sys.argv[1]) + sys.exit(0 if success else 1) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..f769059 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,28 @@ +# PyAXEngine Tests + +## Running Tests + +### All tests 
+```bash +pytest +``` + +### Unit tests only (no hardware required) +```bash +pytest -m "not hardware" +``` + +### Hardware tests only +```bash +pytest -m hardware +``` + +### With coverage +```bash +pytest --cov=axengine --cov-report=term-missing +``` + +## Test Markers + +- `@pytest.mark.hardware`: Tests requiring AX hardware +- `@pytest.mark.unit`: Unit tests without hardware dependency diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9d99889 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +"""Pytest configuration for axengine tests.""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +sys.modules["axengine._axe_capi"] = MagicMock() +sys.modules["axengine._axclrt_capi"] = MagicMock() + +with patch("ctypes.util.find_library", return_value="libax_engine.so"): + pass + + +def pytest_configure(config): + config.addinivalue_line("markers", "hardware: tests that require AX hardware") + config.addinivalue_line("markers", "unit: unit tests that don't require hardware") + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr("axengine._providers.providers", ["AxEngineExecutionProvider"]) diff --git a/tests/test_base_session.py b/tests/test_base_session.py new file mode 100644 index 0000000..91cc1f6 --- /dev/null +++ b/tests/test_base_session.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr("axengine._providers.providers", ["AxEngineExecutionProvider"]) + + +@pytest.mark.unit +class TestSessionOptions: + def test_creation(self): + from axengine import SessionOptions + + opts = SessionOptions() + assert opts is not None + + def test_is_class(self): + from axengine import SessionOptions + + assert isinstance(SessionOptions, type) + + +@pytest.mark.unit +class TestSession: + def test_initialization(self): + from axengine._base_session import Session + + class 
MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + return [] + + sess = MockSession() + assert sess._shape_count == 0 + assert sess._inputs == [] + assert sess._outputs == [] + + def test_validate_input_success(self): + from axengine._base_session import Session + from axengine._node import NodeArg + + class MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + return [] + + sess = MockSession() + sess._inputs = [[NodeArg("input1", "float32", (1, 3))]] + feed = {"input1": np.array([1, 2, 3])} + sess._validate_input(feed) + + def test_validate_input_missing(self): + from axengine._base_session import Session + from axengine._node import NodeArg + + class MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + return [] + + sess = MockSession() + sess._inputs = [[NodeArg("input1", "float32", (1, 3))]] + feed = {} + with pytest.raises(ValueError, match="Required inputs"): + sess._validate_input(feed) + + def test_validate_output_success(self): + from axengine._base_session import Session + from axengine._node import NodeArg + + class MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + return [] + + sess = MockSession() + sess._outputs = [[NodeArg("output1", "float32", (1, 10))]] + sess._validate_output(["output1"]) + + def test_validate_output_invalid(self): + from axengine._base_session import Session + from axengine._node import NodeArg + + class MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + return [] + + sess = MockSession() + sess._outputs = [[NodeArg("output1", "float32", (1, 10))]] + with pytest.raises(ValueError, match="not in model outputs"): + sess._validate_output(["invalid"]) + + def test_validate_output_none(self): + from axengine._base_session import Session + from axengine._node import NodeArg + + class MockSession(Session): + def run(self, output_names, input_feed, run_options=None): + 
return [] + + sess = MockSession() + sess._outputs = [[NodeArg("output1", "float32", (1, 10))]] + sess._validate_output(None) diff --git a/tests/test_characterization.py b/tests/test_characterization.py new file mode 100644 index 0000000..a5bb177 --- /dev/null +++ b/tests/test_characterization.py @@ -0,0 +1,61 @@ +""" +Characterization tests for PyAXEngine - capturing current behavior. + +These tests document the current state of the API, including any bugs. +DO NOT fix bugs here - just record what currently happens. +""" +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + """Mock providers to allow import without hardware.""" + monkeypatch.setattr('axengine._providers.providers', ['AxEngineExecutionProvider']) + + +def test_axengine_imports(): + """Test that basic axengine imports work.""" + import axengine + from axengine import InferenceSession, NodeArg, SessionOptions + assert axengine is not None + assert InferenceSession is not None + assert NodeArg is not None + assert SessionOptions is not None + + +def test_node_arg_creation(): + """Test NodeArg can be instantiated.""" + from axengine import NodeArg + node = NodeArg(name="test", dtype="float32", shape=(1, 3, 224, 224)) + assert node.name == "test" + assert node.dtype == "float32" + assert node.shape == (1, 3, 224, 224) + + +def test_session_options_creation(): + """Test SessionOptions can be instantiated.""" + from axengine import SessionOptions + opts = SessionOptions() + assert opts is not None + + +def test_inference_session_signature(): + """Test InferenceSession has expected __init__ signature.""" + import inspect + + from axengine import InferenceSession + sig = inspect.signature(InferenceSession.__init__) + params = list(sig.parameters.keys()) + assert 'self' in params + assert 'path_or_bytes' in params + assert 'sess_options' in params + assert 'providers' in params + + +def test_node_arg_attributes(): + """Test NodeArg has expected attributes.""" + from axengine 
import NodeArg + node = NodeArg(name="input", dtype="uint8", shape=(1, 224, 224, 3)) + assert hasattr(node, 'name') + assert hasattr(node, 'dtype') + assert hasattr(node, 'shape') diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..38ad2a5 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest + + +@pytest.mark.hardware +class TestHardwareIntegration: + + def test_axengine_session_creation(self): + from axengine import InferenceSession + sess = InferenceSession('model.axmodel', providers='AxEngineExecutionProvider') + assert sess is not None + assert sess.get_providers() == 'AxEngineExecutionProvider' + + def test_axclrt_session_creation(self): + from axengine import InferenceSession + sess = InferenceSession('model.axmodel', providers='AXCLRTExecutionProvider') + assert sess is not None + assert sess.get_providers() == 'AXCLRTExecutionProvider' + + def test_inference_run(self): + from axengine import InferenceSession + sess = InferenceSession('model.axmodel') + inputs = sess.get_inputs() + input_data = {inputs[0].name: np.random.randn(*inputs[0].shape).astype(np.float32)} + outputs = sess.run(None, input_data) + assert len(outputs) > 0 + + def test_context_manager(self): + from axengine import InferenceSession + with InferenceSession('model.axmodel') as sess: + inputs = sess.get_inputs() + assert len(inputs) > 0 diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..7edcfd8 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,39 @@ +import logging + +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr('axengine._providers.providers', ['AxEngineExecutionProvider']) + + +@pytest.mark.unit +class TestLogging: + + def test_get_logger_returns_logger(self): + from axengine._logging import get_logger + logger = get_logger("test") + assert isinstance(logger, logging.Logger) + + def 
test_logger_has_handler(self): + from axengine._logging import get_logger + logger = get_logger("test_handler") + assert len(logger.handlers) > 0 + + def test_logger_default_level(self, monkeypatch): + from axengine._logging import get_logger + monkeypatch.delenv('AXENGINE_LOG_LEVEL', raising=False) + logger = get_logger("test_default") + assert logger.level == logging.INFO + + def test_logger_custom_level(self, monkeypatch): + from axengine._logging import get_logger + monkeypatch.setenv('AXENGINE_LOG_LEVEL', 'DEBUG') + logger = get_logger("test_debug") + assert logger.level == logging.DEBUG + + def test_logger_name(self): + from axengine._logging import get_logger + logger = get_logger("my.module") + assert logger.name == "my.module" diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 0000000..1c38fe8 --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,50 @@ +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr("axengine._providers.providers", ["AxEngineExecutionProvider"]) + + +@pytest.mark.unit +class TestNodeArg: + def test_creation(self): + from axengine import NodeArg + + node = NodeArg(name="input", dtype="float32", shape=(1, 3, 224, 224)) + assert node.name == "input" + assert node.dtype == "float32" + assert node.shape == (1, 3, 224, 224) + + def test_different_dtypes(self): + from axengine import NodeArg + + dtypes = ["uint8", "int8", "uint16", "int16", "uint32", "int32", "float32", "bfloat16"] + for dtype in dtypes: + node = NodeArg(name="test", dtype=dtype, shape=(1,)) + assert node.dtype == dtype + + def test_different_shapes(self): + from axengine import NodeArg + + shapes = [(1,), (1, 3), (1, 3, 224), (1, 3, 224, 224), (2, 4, 8, 16, 32)] + for shape in shapes: + node = NodeArg(name="test", dtype="float32", shape=shape) + assert node.shape == shape + + def test_empty_name(self): + from axengine import NodeArg + + node = NodeArg(name="", dtype="float32", shape=(1,)) + assert 
node.name == "" + + def test_attributes_mutable(self): + from axengine import NodeArg + + node = NodeArg(name="input", dtype="float32", shape=(1, 3, 224, 224)) + node.name = "output" + node.dtype = "int8" + node.shape = (1, 10) + assert node.name == "output" + assert node.dtype == "int8" + assert node.shape == (1, 10) diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 0000000..560365f --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,28 @@ +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr('axengine._providers.providers', ['AxEngineExecutionProvider']) + + +@pytest.mark.unit +class TestProviders: + + def test_get_all_providers(self): + from axengine._providers import axclrt_provider_name, axengine_provider_name, get_all_providers + providers = get_all_providers() + assert isinstance(providers, list) + assert axengine_provider_name in providers + assert axclrt_provider_name in providers + assert len(providers) == 2 + + def test_get_available_providers(self): + from axengine._providers import get_available_providers + providers = get_available_providers() + assert isinstance(providers, list) + + def test_provider_names(self): + from axengine._providers import axclrt_provider_name, axengine_provider_name + assert axengine_provider_name == "AxEngineExecutionProvider" + assert axclrt_provider_name == "AXCLRTExecutionProvider" diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..6d1f623 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,42 @@ +from unittest.mock import patch + +import pytest + + +@pytest.fixture(autouse=True) +def mock_providers(monkeypatch): + monkeypatch.setattr('axengine._providers.providers', ['AxEngineExecutionProvider']) + + +@pytest.mark.unit +class TestInferenceSession: + + def test_init_unavailable_provider(self): + from axengine import InferenceSession + with 
patch('axengine._session.get_available_providers', return_value=['AxEngineExecutionProvider']): + with pytest.raises(ValueError, match="not available"): + InferenceSession('model.axmodel', providers='InvalidProvider') + + def test_init_invalid_provider_type(self): + from axengine import InferenceSession + with patch('axengine._session.get_available_providers', return_value=['AxEngineExecutionProvider']): + with pytest.raises(TypeError, match="Invalid provider type"): + InferenceSession('model.axmodel', providers=[123]) + + def test_init_invalid_tuple_length(self): + from axengine import InferenceSession + with patch('axengine._session.get_available_providers', return_value=['AxEngineExecutionProvider']): + with pytest.raises(ValueError, match="tuple with 2 elements"): + InferenceSession('model.axmodel', providers=[('Provider',)]) + + def test_init_invalid_tuple_name_type(self): + from axengine import InferenceSession + with patch('axengine._session.get_available_providers', return_value=['AxEngineExecutionProvider']): + with pytest.raises(TypeError): + InferenceSession('model.axmodel', providers=[(123, {})]) + + def test_init_invalid_tuple_dict_type(self): + from axengine import InferenceSession + with patch('axengine._session.get_available_providers', return_value=['AxEngineExecutionProvider']): + with pytest.raises(TypeError): + InferenceSession('model.axmodel', providers=[('Provider', 'not_dict')])