From 1b503e0b26ef1290f384fc10f022d51055843648 Mon Sep 17 00:00:00 2001 From: mike0sv Date: Wed, 19 Aug 2020 23:07:17 +0300 Subject: [PATCH 1/4] EBNT-284 custom metrics --- src/ebonite/core/objects/wrapper.py | 1 - src/ebonite/ext/ext_loader.py | 3 +- src/ebonite/ext/onnx/__init__.py | 3 + src/ebonite/ext/onnx/model.py | 107 ++++++++++++++++++++++++++++ test.requirements.txt | 2 + tests/ext/test_onnx/__init__.py | 0 tests/ext/test_onnx/model.py | 43 +++++++++++ 7 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 src/ebonite/ext/onnx/__init__.py create mode 100644 src/ebonite/ext/onnx/model.py create mode 100644 tests/ext/test_onnx/__init__.py create mode 100644 tests/ext/test_onnx/model.py diff --git a/src/ebonite/core/objects/wrapper.py b/src/ebonite/core/objects/wrapper.py index 2ad9d490..8b83739c 100644 --- a/src/ebonite/core/objects/wrapper.py +++ b/src/ebonite/core/objects/wrapper.py @@ -193,7 +193,6 @@ def _exposed_methods_mapping(self) -> typing.Dict[str, str]: this allows to wrap existing API with your own pre/postprocessing. Otherwise, wrapped model object method is going to be called. """ - pass # pragma: no cover @staticmethod def with_model(f): diff --git a/src/ebonite/ext/ext_loader.py b/src/ebonite/ext/ext_loader.py index 07c6c95b..c500418d 100644 --- a/src/ebonite/ext/ext_loader.py +++ b/src/ebonite/ext/ext_loader.py @@ -77,7 +77,8 @@ class ExtensionLoader: Extension('ebonite.ext.imageio', ['imageio']), Extension('ebonite.ext.lightgbm', ['lightgbm'], False), Extension('ebonite.ext.xgboost', ['xgboost'], False), - Extension('ebonite.ext.docker', ['docker'], False) + Extension('ebonite.ext.docker', ['docker'], False), + Extension('ebonite.ext.onnx', ['onnx'], False) ) _loaded_extensions: Dict[Extension, ModuleType] = {} diff --git a/src/ebonite/ext/onnx/__init__.py b/src/ebonite/ext/onnx/__init__.py new file mode 100644 index 00000000..32358664 --- /dev/null +++ b/src/ebonite/ext/onnx/__init__.py @@ -0,0 +1,3 @@ +from .model import ONNXModelHook, ONNXModelWrapper + +__all__ = ['ONNXModelWrapper', 'ONNXModelHook'] diff --git a/src/ebonite/ext/onnx/model.py b/src/ebonite/ext/onnx/model.py new file mode 100644 index 00000000..90edde9c --- /dev/null +++ b/src/ebonite/ext/onnx/model.py @@ -0,0 +1,107 @@ +import contextlib +import os +from abc import abstractmethod +from typing import Dict, List, Type, Union + +import onnx +from onnx import ModelProto +from pyjackson.decorators import cached_property + +from ebonite.core.analyzer import TypeHookMixin +from ebonite.core.analyzer.model import BindingModelHook +from ebonite.core.objects import ModelIO, ModelWrapper, Requirements +from ebonite.core.objects.artifacts import Blobs, InMemoryBlob +from ebonite.core.objects.wrapper import FilesContextManager +from ebonite.utils.abc_utils import is_abstract_method +from ebonite.utils.importing import module_importable + +_DEFAULT_BACKEND = 'onnxruntime' + + +def set_default_onnx_backend(backend: Union[Type['ONNXInferenceBackend'], str]): + global _DEFAULT_BACKEND + if not isinstance(backend, str): + backend = backend.name + if backend not in ONNXInferenceBackend.subtypes: + raise ValueError(f'unknown onnx backend {backend}') + _DEFAULT_BACKEND = backend + + +class ONNXModelIO(ModelIO): + FILENAME = 'model.onnx' + + @contextlib.contextmanager + def dump(self, model: ModelProto) -> FilesContextManager: + yield Blobs({self.FILENAME: InMemoryBlob(model.SerializeToString())}) # TODO change to LazyBlob + + def load(self, path): + return onnx.load(os.path.join(path, self.FILENAME)) + + +class ONNXInferenceBackend: + subtypes: Dict[str, Type['ONNXInferenceBackend']] = {} + name: str + requirements: List[str] + + def __init__(self, wrapper: 'ONNXModelWrapper'): + self.wrapper = wrapper + + def __init_subclass__(cls, **kwargs): + if not is_abstract_method(cls.run) and not is_abstract_method(cls.is_available): + if 'name' not in cls.__dict__ or 'requirements' not in cls.__dict__: + raise AttributeError(f'provide name and requirements fields for {cls}') + ONNXInferenceBackend.subtypes[cls.name] = cls + super().__init_subclass__(**kwargs) + + @abstractmethod + def run(self, data): + """""" + + @classmethod + def is_available(cls): + return all(module_importable(m) for m in cls.requirements) + + +class ONNXRuntimeBackend(ONNXInferenceBackend): + name = 'onnxruntime' + requirements = ['onnxruntime'] + _session = None + + @cached_property + def session(self): + import onnxruntime as rt + if self._session is None: + self._session = rt.InferenceSession(self.wrapper.model.SerializeToString()) + return self._session + + def run(self, data): + return self.session.run([o.name for o in self.session.get_outputs()], data) + + +class ONNXModelWrapper(ModelWrapper): + model: ModelProto + + def __init__(self, io: ModelIO, backend: str = None): + super().__init__(io) + self.backend = backend or _DEFAULT_BACKEND + if self.backend not in ONNXInferenceBackend.subtypes: + raise ValueError(f'unknown onnx backend {self.backend}') + self._backend: ONNXInferenceBackend = ONNXInferenceBackend.subtypes[self.backend](self) + + def run(self, data): + if not self._backend.is_available(): + raise RuntimeError(f'{self.backend} inference backend is unavailable') + return self._backend.run(data) + + def _exposed_methods_mapping(self) -> Dict[str, str]: + return {'predict': 'run'} + + def _model_requirements(self) -> Requirements: + return super(ONNXModelWrapper, self)._model_requirements() + self._backend.requirements + + +class ONNXModelHook(BindingModelHook, TypeHookMixin): + valid_types = [ModelProto] + + def _wrapper_factory(self) -> ModelWrapper: + return ONNXModelWrapper(ONNXModelIO()) diff --git a/test.requirements.txt b/test.requirements.txt index 9ce73b65..2b026e90 100644 --- a/test.requirements.txt +++ b/test.requirements.txt @@ -18,6 +18,8 @@ testcontainers==2.6.0 pytest==5.2.2 xgboost==1.0.2 lightgbm==2.3.1 +onnx==1.7.0 +onnxruntime==1.4.0 torch==1.4.0+cpu ; sys_platform != "darwin" diff --git a/tests/ext/test_onnx/__init__.py b/tests/ext/test_onnx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ext/test_onnx/model.py b/tests/ext/test_onnx/model.py new file mode 100644 index 00000000..43ed35f4 --- /dev/null +++ b/tests/ext/test_onnx/model.py @@ -0,0 +1,43 @@ +import numpy as np +import onnx +import pytest +from onnxruntime.datasets import get_example +from pyjackson import deserialize, serialize + +from ebonite.core.analyzer.model import ModelAnalyzer +from ebonite.core.objects import ModelWrapper +from ebonite.ext.onnx.model import ONNXModelWrapper + + +@pytest.fixture +def onnx_model(): + return onnx.load(get_example('sigmoid.onnx')) + + +@pytest.fixture +def onnx_input(): + x = np.random.random((3, 4, 5)).astype(np.float32) + return {'x': x} + + +@pytest.fixture +def onnx_wrapper(onnx_model, onnx_input): + return ModelAnalyzer.analyze(onnx_model, input_data=onnx_input) + + +def test_onnx_hook(onnx_wrapper): + assert isinstance(onnx_wrapper, ONNXModelWrapper) + + +def test_onnx_io(onnx_wrapper: ModelWrapper, tmpdir, onnx_input): + with onnx_wrapper.dump() as artifacts: + artifacts.materialize(tmpdir) + + onnx_wrapper: ONNXModelWrapper = deserialize(serialize(onnx_wrapper), ModelWrapper) + assert isinstance(onnx_wrapper, ONNXModelWrapper) + onnx_wrapper.load(tmpdir) + predict = onnx_wrapper.run(onnx_input) + assert isinstance(predict, list) + assert len(predict) == 1 + tensor = predict[0] + assert isinstance(tensor, np.ndarray) From 0571ba3e93f707afb94504a807968f546a78c738 Mon Sep 17 00:00:00 2001 From: mike0sv Date: Mon, 24 Aug 2020 10:04:01 +0300 Subject: [PATCH 2/4] EBNT-422 more tests --- tests/ext/test_onnx/model.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/ext/test_onnx/model.py b/tests/ext/test_onnx/model.py index 43ed35f4..22cc79c8 100644 --- a/tests/ext/test_onnx/model.py +++ b/tests/ext/test_onnx/model.py @@ -6,7 +6,7 @@ from ebonite.core.analyzer.model import ModelAnalyzer from ebonite.core.objects import ModelWrapper -from ebonite.ext.onnx.model import ONNXModelWrapper +from ebonite.ext.onnx.model import ONNXModelWrapper, ONNXInferenceBackend @pytest.fixture @@ -41,3 +41,35 @@ def test_onnx_io(onnx_wrapper: ModelWrapper, tmpdir, onnx_input): assert len(predict) == 1 tensor = predict[0] assert isinstance(tensor, np.ndarray) + + +def test_onnx_backend_subclass(): + class AbstractSubclass(ONNXInferenceBackend): + name = 'abstract' + requirements = [] + + assert AbstractSubclass.name not in ONNXInferenceBackend.subtypes + + with pytest.raises(AttributeError): + class NoName(ONNXInferenceBackend): + requirements = [] + + def run(self, data): + pass + + with pytest.raises(AttributeError): + class NoReqs(ONNXInferenceBackend): + name = 'aaa' + + def run(self, data): + pass + + class Good(ONNXInferenceBackend): + name = 'aaa' + requirements = [] + + def run(self, data): + pass + + assert Good.name in ONNXInferenceBackend.subtypes + del ONNXInferenceBackend.subtypes[Good.name] From 9e96a47781ac0be6c878a44f024c606f3a55bc69 Mon Sep 17 00:00:00 2001 From: mike0sv Date: Mon, 24 Aug 2020 10:05:03 +0300 Subject: [PATCH 3/4] EBNT-422 more tests --- tests/ext/test_onnx/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ext/test_onnx/model.py b/tests/ext/test_onnx/model.py index 22cc79c8..513b8ad2 100644 --- a/tests/ext/test_onnx/model.py +++ b/tests/ext/test_onnx/model.py @@ -6,7 +6,7 @@ from ebonite.core.analyzer.model import ModelAnalyzer from ebonite.core.objects import ModelWrapper -from ebonite.ext.onnx.model import ONNXModelWrapper, ONNXInferenceBackend +from ebonite.ext.onnx.model import ONNXInferenceBackend, ONNXModelWrapper @pytest.fixture From 7b4a14ed713f35a4a293f5fb35e5d385fe043448 Mon Sep 17 00:00:00 2001 From: mike0sv Date: Wed, 26 Aug 2020 16:35:43 +0300 Subject: [PATCH 4/4] EBNT-342 fix tests --- tests/ext/test_onnx/{model.py => test_model.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/ext/test_onnx/{model.py => test_model.py} (100%) diff --git a/tests/ext/test_onnx/model.py b/tests/ext/test_onnx/test_model.py similarity index 100% rename from tests/ext/test_onnx/model.py rename to tests/ext/test_onnx/test_model.py