-
Notifications
You must be signed in to change notification settings - Fork 10
EBNT-422 onnx #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
EBNT-422 onnx #128
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .model import ONNXModelHook, ONNXModelWrapper | ||
|
|
||
| __all__ = ['ONNXModelWrapper', 'ONNXModelHook'] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| import contextlib | ||
| import os | ||
| from abc import abstractmethod | ||
| from typing import Dict, List, Type, Union | ||
|
|
||
| import onnx | ||
| from onnx import ModelProto | ||
| from pyjackson.decorators import cached_property | ||
|
|
||
| from ebonite.core.analyzer import TypeHookMixin | ||
| from ebonite.core.analyzer.model import BindingModelHook | ||
| from ebonite.core.objects import ModelIO, ModelWrapper, Requirements | ||
| from ebonite.core.objects.artifacts import Blobs, InMemoryBlob | ||
| from ebonite.core.objects.wrapper import FilesContextManager | ||
| from ebonite.utils.abc_utils import is_abstract_method | ||
| from ebonite.utils.importing import module_importable | ||
|
|
||
| _DEFAULT_BACKEND = 'onnxruntime' | ||
|
|
||
|
|
||
| def set_default_onnx_backend(backend: Union[Type['ONNXInferenceBackend'], str]): | ||
| global _DEFAULT_BACKEND | ||
| if not isinstance(backend, str): | ||
| backend = backend.name | ||
| if backend not in ONNXInferenceBackend.subtypes: | ||
| raise ValueError(f'unknown onnx backend {backend}') | ||
| _DEFAULT_BACKEND = backend | ||
|
|
||
|
|
||
| class ONNXModelIO(ModelIO): | ||
| FILENAME = 'model.onnx' | ||
|
|
||
| @contextlib.contextmanager | ||
| def dump(self, model: ModelProto) -> FilesContextManager: | ||
| yield Blobs({self.FILENAME: InMemoryBlob(model.SerializeToString())}) # TODO change to LazyBlob | ||
|
|
||
| def load(self, path): | ||
| return onnx.load(os.path.join(path, self.FILENAME)) | ||
|
|
||
|
|
||
| class ONNXInferenceBackend: | ||
| subtypes: Dict[str, Type['ONNXInferenceBackend']] = {} | ||
| name: str | ||
| requirements: List[str] | ||
|
|
||
| def __init__(self, wrapper: 'ONNXModelWrapper'): | ||
| self.wrapper = wrapper | ||
|
|
||
| def __init_subclass__(cls, **kwargs): | ||
| if not is_abstract_method(cls.run) and not is_abstract_method(cls.is_available): | ||
|
mike0sv marked this conversation as resolved.
|
||
| if 'name' not in cls.__dict__ or 'requirements' not in cls.__dict__: | ||
| raise AttributeError(f'provide name and requirements fields for {cls}') | ||
| ONNXInferenceBackend.subtypes[cls.name] = cls | ||
| super().__init_subclass__(**kwargs) | ||
|
|
||
| @abstractmethod | ||
| def run(self, data): | ||
| """""" | ||
|
|
||
| @classmethod | ||
| def is_available(cls): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test what? it's one line and calls one function, which is tested separately. and it will be covered in report since it will be called within other tests
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://coveralls.io/builds/32948660/source?filename=src%2Febonite%2Fext%2Fonnx%2Fmodel.py Coveralls seem to have other thoughts about that |
||
| return all(module_importable(m) for m in cls.requirements) | ||
|
|
||
|
|
||
| class ONNXRuntimeBackend(ONNXInferenceBackend): | ||
| name = 'onnxruntime' | ||
| requirements = ['onnxruntime'] | ||
| _session = None | ||
|
|
||
| @cached_property | ||
| def session(self): | ||
| import onnxruntime as rt | ||
| if self._session is None: | ||
|
mike0sv marked this conversation as resolved.
|
||
| self._session = rt.InferenceSession(self.wrapper.model.SerializeToString()) | ||
| return self._session | ||
|
|
||
| def run(self, data): | ||
| return self.session.run([o.name for o in self.session.get_outputs()], data) | ||
|
|
||
|
|
||
| class ONNXModelWrapper(ModelWrapper): | ||
| model: ModelProto | ||
|
|
||
| def __init__(self, io: ModelIO, backend: str = None): | ||
| super().__init__(io) | ||
| self.backend = backend or _DEFAULT_BACKEND | ||
| if self.backend not in ONNXInferenceBackend.subtypes: | ||
| raise ValueError(f'unknown onnx backend {self.backend}') | ||
| self._backend: ONNXInferenceBackend = ONNXInferenceBackend.subtypes[self.backend](self) | ||
|
|
||
| def run(self, data): | ||
| if not self._backend.is_available(): | ||
| raise RuntimeError(f'{self.backend} inference backend is unavailable') | ||
| return self._backend.run(data) | ||
|
|
||
| def _exposed_methods_mapping(self) -> Dict[str, str]: | ||
| return {'predict': 'run'} | ||
|
|
||
| def _model_requirements(self) -> Requirements: | ||
| return super(ONNXModelWrapper, self)._model_requirements() + self._backend.requirements | ||
|
|
||
|
|
||
| class ONNXModelHook(BindingModelHook, TypeHookMixin): | ||
| valid_types = [ModelProto] | ||
|
|
||
| def _wrapper_factory(self) -> ModelWrapper: | ||
| return ONNXModelWrapper(ONNXModelIO()) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| import numpy as np | ||
| import onnx | ||
| import pytest | ||
| from onnxruntime.datasets import get_example | ||
| from pyjackson import deserialize, serialize | ||
|
|
||
| from ebonite.core.analyzer.model import ModelAnalyzer | ||
| from ebonite.core.objects import ModelWrapper | ||
| from ebonite.ext.onnx.model import ONNXInferenceBackend, ONNXModelWrapper | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def onnx_model(): | ||
| return onnx.load(get_example('sigmoid.onnx')) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def onnx_input(): | ||
| x = np.random.random((3, 4, 5)).astype(np.float32) | ||
| return {'x': x} | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def onnx_wrapper(onnx_model, onnx_input): | ||
| return ModelAnalyzer.analyze(onnx_model, input_data=onnx_input) | ||
|
|
||
|
|
||
| def test_onnx_hook(onnx_wrapper): | ||
| assert isinstance(onnx_wrapper, ONNXModelWrapper) | ||
|
|
||
|
|
||
| def test_onnx_io(onnx_wrapper: ModelWrapper, tmpdir, onnx_input): | ||
| with onnx_wrapper.dump() as artifacts: | ||
| artifacts.materialize(tmpdir) | ||
|
|
||
| onnx_wrapper: ONNXModelWrapper = deserialize(serialize(onnx_wrapper), ModelWrapper) | ||
| assert isinstance(onnx_wrapper, ONNXModelWrapper) | ||
| onnx_wrapper.load(tmpdir) | ||
| predict = onnx_wrapper.run(onnx_input) | ||
| assert isinstance(predict, list) | ||
| assert len(predict) == 1 | ||
| tensor = predict[0] | ||
| assert isinstance(tensor, np.ndarray) | ||
|
|
||
|
|
||
| def test_onnx_backend_subclass(): | ||
| class AbstractSubclass(ONNXInferenceBackend): | ||
| name = 'abstract' | ||
| requirements = [] | ||
|
|
||
| assert AbstractSubclass.name not in ONNXInferenceBackend.subtypes | ||
|
|
||
| with pytest.raises(AttributeError): | ||
| class NoName(ONNXInferenceBackend): | ||
| requirements = [] | ||
|
|
||
| def run(self, data): | ||
| pass | ||
|
|
||
| with pytest.raises(AttributeError): | ||
| class NoReqs(ONNXInferenceBackend): | ||
| name = 'aaa' | ||
|
|
||
| def run(self, data): | ||
| pass | ||
|
|
||
| class Good(ONNXInferenceBackend): | ||
| name = 'aaa' | ||
| requirements = [] | ||
|
|
||
| def run(self, data): | ||
| pass | ||
|
|
||
| assert Good.name in ONNXInferenceBackend.subtypes | ||
| del ONNXInferenceBackend.subtypes[Good.name] |
Uh oh!
There was an error while loading. Please reload this page.