diff --git a/src/ebonite/core/objects/wrapper.py b/src/ebonite/core/objects/wrapper.py index 2ad9d490..8b83739c 100644 --- a/src/ebonite/core/objects/wrapper.py +++ b/src/ebonite/core/objects/wrapper.py @@ -193,7 +193,6 @@ def _exposed_methods_mapping(self) -> typing.Dict[str, str]: this allows to wrap existing API with your own pre/postprocessing. Otherwise, wrapped model object method is going to be called. """ - pass # pragma: no cover @staticmethod def with_model(f): diff --git a/src/ebonite/ext/ext_loader.py b/src/ebonite/ext/ext_loader.py index 07c6c95b..c500418d 100644 --- a/src/ebonite/ext/ext_loader.py +++ b/src/ebonite/ext/ext_loader.py @@ -77,7 +77,8 @@ class ExtensionLoader: Extension('ebonite.ext.imageio', ['imageio']), Extension('ebonite.ext.lightgbm', ['lightgbm'], False), Extension('ebonite.ext.xgboost', ['xgboost'], False), - Extension('ebonite.ext.docker', ['docker'], False) + Extension('ebonite.ext.docker', ['docker'], False), + Extension('ebonite.ext.onnx', ['onnx'], False) ) _loaded_extensions: Dict[Extension, ModuleType] = {} diff --git a/src/ebonite/ext/onnx/__init__.py b/src/ebonite/ext/onnx/__init__.py new file mode 100644 index 00000000..32358664 --- /dev/null +++ b/src/ebonite/ext/onnx/__init__.py @@ -0,0 +1,3 @@ +from .model import ONNXModelHook, ONNXModelWrapper + +__all__ = ['ONNXModelWrapper', 'ONNXModelHook'] diff --git a/src/ebonite/ext/onnx/model.py b/src/ebonite/ext/onnx/model.py new file mode 100644 index 00000000..90edde9c --- /dev/null +++ b/src/ebonite/ext/onnx/model.py @@ -0,0 +1,107 @@ +import contextlib +import os +from abc import abstractmethod +from typing import Dict, List, Type, Union + +import onnx +from onnx import ModelProto +from pyjackson.decorators import cached_property + +from ebonite.core.analyzer import TypeHookMixin +from ebonite.core.analyzer.model import BindingModelHook +from ebonite.core.objects import ModelIO, ModelWrapper, Requirements +from ebonite.core.objects.artifacts import Blobs, InMemoryBlob +from ebonite.core.objects.wrapper import FilesContextManager +from ebonite.utils.abc_utils import is_abstract_method +from ebonite.utils.importing import module_importable + +_DEFAULT_BACKEND = 'onnxruntime' + + +def set_default_onnx_backend(backend: Union[Type['ONNXInferenceBackend'], str]): + global _DEFAULT_BACKEND + if not isinstance(backend, str): + backend = backend.name + if backend not in ONNXInferenceBackend.subtypes: + raise ValueError(f'unknown onnx backend {backend}') + _DEFAULT_BACKEND = backend + + +class ONNXModelIO(ModelIO): + FILENAME = 'model.onnx' + + @contextlib.contextmanager + def dump(self, model: ModelProto) -> FilesContextManager: + yield Blobs({self.FILENAME: InMemoryBlob(model.SerializeToString())}) # TODO change to LazyBlob + + def load(self, path): + return onnx.load(os.path.join(path, self.FILENAME)) + + +class ONNXInferenceBackend: + subtypes: Dict[str, Type['ONNXInferenceBackend']] = {} + name: str + requirements: List[str] + + def __init__(self, wrapper: 'ONNXModelWrapper'): + self.wrapper = wrapper + + def __init_subclass__(cls, **kwargs): + if not is_abstract_method(cls.run) and not is_abstract_method(cls.is_available): + if 'name' not in cls.__dict__ or 'requirements' not in cls.__dict__: + raise AttributeError(f'provide name and requirements fields for {cls}') + ONNXInferenceBackend.subtypes[cls.name] = cls + super().__init_subclass__(**kwargs) + + @abstractmethod + def run(self, data): + """""" + + @classmethod + def is_available(cls): + return all(module_importable(m) for m in cls.requirements) + + +class ONNXRuntimeBackend(ONNXInferenceBackend): + name = 'onnxruntime' + requirements = ['onnxruntime'] + _session = None + + @cached_property + def session(self): + import onnxruntime as rt + if self._session is None: + self._session = rt.InferenceSession(self.wrapper.model.SerializeToString()) + return self._session + + def run(self, data): + return self.session.run([o.name for o in self.session.get_outputs()], data) + + +class ONNXModelWrapper(ModelWrapper): + model: ModelProto + + def __init__(self, io: ModelIO, backend: str = None): + super().__init__(io) + self.backend = backend or _DEFAULT_BACKEND + if self.backend not in ONNXInferenceBackend.subtypes: + raise ValueError(f'unknown onnx backend {self.backend}') + self._backend: ONNXInferenceBackend = ONNXInferenceBackend.subtypes[self.backend](self) + + def run(self, data): + if not self._backend.is_available(): + raise RuntimeError(f'{self.backend} inference backend is unavailable') + return self._backend.run(data) + + def _exposed_methods_mapping(self) -> Dict[str, str]: + return {'predict': 'run'} + + def _model_requirements(self) -> Requirements: + return super(ONNXModelWrapper, self)._model_requirements() + self._backend.requirements + + +class ONNXModelHook(BindingModelHook, TypeHookMixin): + valid_types = [ModelProto] + + def _wrapper_factory(self) -> ModelWrapper: + return ONNXModelWrapper(ONNXModelIO()) diff --git a/test.requirements.txt b/test.requirements.txt index 9ce73b65..2b026e90 100644 --- a/test.requirements.txt +++ b/test.requirements.txt @@ -18,6 +18,8 @@ testcontainers==2.6.0 pytest==5.2.2 xgboost==1.0.2 lightgbm==2.3.1 +onnx==1.7.0 +onnxruntime==1.4.0 torch==1.4.0+cpu ; sys_platform != "darwin" diff --git a/tests/ext/test_onnx/__init__.py b/tests/ext/test_onnx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ext/test_onnx/test_model.py b/tests/ext/test_onnx/test_model.py new file mode 100644 index 00000000..513b8ad2 --- /dev/null +++ b/tests/ext/test_onnx/test_model.py @@ -0,0 +1,75 @@ +import numpy as np +import onnx +import pytest +from onnxruntime.datasets import get_example +from pyjackson import deserialize, serialize + +from ebonite.core.analyzer.model import ModelAnalyzer +from ebonite.core.objects import ModelWrapper +from ebonite.ext.onnx.model import ONNXInferenceBackend, ONNXModelWrapper + + +@pytest.fixture +def onnx_model(): + return onnx.load(get_example('sigmoid.onnx')) + + +@pytest.fixture +def onnx_input(): + x = np.random.random((3, 4, 5)).astype(np.float32) + return {'x': x} + + +@pytest.fixture +def onnx_wrapper(onnx_model, onnx_input): + return ModelAnalyzer.analyze(onnx_model, input_data=onnx_input) + + +def test_onnx_hook(onnx_wrapper): + assert isinstance(onnx_wrapper, ONNXModelWrapper) + + +def test_onnx_io(onnx_wrapper: ModelWrapper, tmpdir, onnx_input): + with onnx_wrapper.dump() as artifacts: + artifacts.materialize(tmpdir) + + onnx_wrapper: ONNXModelWrapper = deserialize(serialize(onnx_wrapper), ModelWrapper) + assert isinstance(onnx_wrapper, ONNXModelWrapper) + onnx_wrapper.load(tmpdir) + predict = onnx_wrapper.run(onnx_input) + assert isinstance(predict, list) + assert len(predict) == 1 + tensor = predict[0] + assert isinstance(tensor, np.ndarray) + + +def test_onnx_backend_subclass(): + class AbstractSubclass(ONNXInferenceBackend): + name = 'abstract' + requirements = [] + + assert AbstractSubclass.name not in ONNXInferenceBackend.subtypes + + with pytest.raises(AttributeError): + class NoName(ONNXInferenceBackend): + requirements = [] + + def run(self, data): + pass + + with pytest.raises(AttributeError): + class NoReqs(ONNXInferenceBackend): + name = 'aaa' + + def run(self, data): + pass + + class Good(ONNXInferenceBackend): + name = 'aaa' + requirements = [] + + def run(self, data): + pass + + assert Good.name in ONNXInferenceBackend.subtypes + del ONNXInferenceBackend.subtypes[Good.name]