Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ebonite/core/objects/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def _exposed_methods_mapping(self) -> typing.Dict[str, str]:
this allows to wrap existing API with your own pre/postprocessing.
Otherwise, wrapped model object method is going to be called.
"""
pass # pragma: no cover

@staticmethod
def with_model(f):
Expand Down
3 changes: 2 additions & 1 deletion src/ebonite/ext/ext_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class ExtensionLoader:
Extension('ebonite.ext.imageio', ['imageio']),
Extension('ebonite.ext.lightgbm', ['lightgbm'], False),
Extension('ebonite.ext.xgboost', ['xgboost'], False),
Extension('ebonite.ext.docker', ['docker'], False)
Extension('ebonite.ext.docker', ['docker'], False),
Extension('ebonite.ext.onnx', ['onnx'], False)
)

_loaded_extensions: Dict[Extension, ModuleType] = {}
Expand Down
3 changes: 3 additions & 0 deletions src/ebonite/ext/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .model import ONNXModelHook, ONNXModelWrapper

__all__ = ['ONNXModelWrapper', 'ONNXModelHook']
107 changes: 107 additions & 0 deletions src/ebonite/ext/onnx/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib
import os
from abc import abstractmethod
from typing import Dict, List, Type, Union

import onnx
from onnx import ModelProto
from pyjackson.decorators import cached_property

from ebonite.core.analyzer import TypeHookMixin
from ebonite.core.analyzer.model import BindingModelHook
from ebonite.core.objects import ModelIO, ModelWrapper, Requirements
from ebonite.core.objects.artifacts import Blobs, InMemoryBlob
from ebonite.core.objects.wrapper import FilesContextManager
from ebonite.utils.abc_utils import is_abstract_method
from ebonite.utils.importing import module_importable

_DEFAULT_BACKEND = 'onnxruntime'


def set_default_onnx_backend(backend: Union[Type['ONNXInferenceBackend'], str]):
    """Change the backend used by default for ONNX model inference.

    :param backend: registered backend name or an :class:`ONNXInferenceBackend` subclass
    :raises ValueError: if the backend was never registered in ``ONNXInferenceBackend.subtypes``
    """
    global _DEFAULT_BACKEND
    name = backend if isinstance(backend, str) else backend.name
    if name not in ONNXInferenceBackend.subtypes:
        raise ValueError(f'unknown onnx backend {name}')
    _DEFAULT_BACKEND = name


class ONNXModelIO(ModelIO):
    """IO implementation that persists an ONNX :class:`~onnx.ModelProto`
    as a single ``model.onnx`` file."""

    FILENAME = 'model.onnx'

    @contextlib.contextmanager
    def dump(self, model: ModelProto) -> FilesContextManager:
        """Serialize *model* into an in-memory blob keyed by :attr:`FILENAME`."""
        # TODO change to LazyBlob
        serialized = model.SerializeToString()
        yield Blobs({self.FILENAME: InMemoryBlob(serialized)})

    def load(self, path):
        """Read the ONNX model back from the :attr:`FILENAME` file under *path*."""
        return onnx.load(os.path.join(path, self.FILENAME))


class ONNXInferenceBackend:
    """Base class for pluggable ONNX inference backends.

    Concrete subclasses — those overriding both :meth:`run` and
    :meth:`is_available` with non-abstract implementations — are
    automatically registered in :attr:`subtypes` under their ``name``.
    """
    # registry of concrete backends, keyed by backend name
    subtypes: Dict[str, Type['ONNXInferenceBackend']] = {}
    name: str
    requirements: List[str]

    def __init__(self, wrapper: 'ONNXModelWrapper'):
        self.wrapper = wrapper

    def __init_subclass__(cls, **kwargs):
        # only fully-implemented backends go into the registry;
        # abstract intermediate classes are silently skipped
        is_concrete = not (is_abstract_method(cls.run) or is_abstract_method(cls.is_available))
        if is_concrete:
            if 'name' not in cls.__dict__ or 'requirements' not in cls.__dict__:
                raise AttributeError(f'provide name and requirements fields for {cls}')
            ONNXInferenceBackend.subtypes[cls.name] = cls
        super().__init_subclass__(**kwargs)

    @abstractmethod
    def run(self, data):
        """Execute inference on *data* and return the backend's raw output."""

    @classmethod
    def is_available(cls):
        """Check that every module this backend requires can be imported."""
        return all(module_importable(m) for m in cls.requirements)


class ONNXRuntimeBackend(ONNXInferenceBackend):
    """ONNX inference backend built on the ``onnxruntime`` package."""
    name = 'onnxruntime'
    requirements = ['onnxruntime']
    # lazily created InferenceSession, shared by subsequent run() calls
    _session = None

    @cached_property
    def session(self):
        """Create (once) and return the onnxruntime inference session
        for the wrapped model."""
        import onnxruntime as rt
        if self._session is None:
            serialized = self.wrapper.model.SerializeToString()
            self._session = rt.InferenceSession(serialized)
        return self._session

    def run(self, data):
        """Feed *data* to the session, requesting every declared model output."""
        output_names = [o.name for o in self.session.get_outputs()]
        return self.session.run(output_names, data)


class ONNXModelWrapper(ModelWrapper):
    """Model wrapper for ONNX models that delegates inference to a
    pluggable :class:`ONNXInferenceBackend`."""
    # the wrapped onnx model object (set by ModelWrapper machinery)
    model: ModelProto

    def __init__(self, io: ModelIO, backend: str = None):
        """
        :param io: IO implementation used to persist the model
        :param backend: name of a registered inference backend;
            falls back to the module-wide default when ``None``
        :raises ValueError: if the backend name was never registered
        """
        super().__init__(io)
        self.backend = backend if backend is not None else _DEFAULT_BACKEND
        backend_type = ONNXInferenceBackend.subtypes.get(self.backend)
        if backend_type is None:
            raise ValueError(f'unknown onnx backend {self.backend}')
        self._backend: ONNXInferenceBackend = backend_type(self)

    def run(self, data):
        """Run inference on *data* through the configured backend.

        :raises RuntimeError: if the backend's required modules are not importable
        """
        if not self._backend.is_available():
            raise RuntimeError(f'{self.backend} inference backend is unavailable')
        return self._backend.run(data)

    def _exposed_methods_mapping(self) -> Dict[str, str]:
        """Expose :meth:`run` to clients under the public name ``predict``."""
        return {'predict': 'run'}

    def _model_requirements(self) -> Requirements:
        """Base wrapper requirements plus the modules the backend needs."""
        return super()._model_requirements() + self._backend.requirements


class ONNXModelHook(BindingModelHook, TypeHookMixin):
    """Analyzer hook that recognizes ``onnx.ModelProto`` objects."""
    valid_types = [ModelProto]

    def _wrapper_factory(self) -> ModelWrapper:
        """Create a fresh :class:`ONNXModelWrapper` with file-based ONNX IO."""
        io = ONNXModelIO()
        return ONNXModelWrapper(io)
2 changes: 2 additions & 0 deletions test.requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ testcontainers==2.6.0
pytest==5.2.2
xgboost==1.0.2
lightgbm==2.3.1
onnx==1.7.0
onnxruntime==1.4.0

torch==1.4.0+cpu ; sys_platform != "darwin"

Expand Down
Empty file added tests/ext/test_onnx/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions tests/ext/test_onnx/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np
import onnx
import pytest
from onnxruntime.datasets import get_example
from pyjackson import deserialize, serialize

from ebonite.core.analyzer.model import ModelAnalyzer
from ebonite.core.objects import ModelWrapper
from ebonite.ext.onnx.model import ONNXInferenceBackend, ONNXModelWrapper


@pytest.fixture
def onnx_model():
    """Sigmoid example model shipped with onnxruntime."""
    example_path = get_example('sigmoid.onnx')
    return onnx.load(example_path)


@pytest.fixture
def onnx_input():
    """Random float32 tensor feed dict for the sigmoid model's 'x' input."""
    tensor = np.random.random((3, 4, 5)).astype(np.float32)
    return {'x': tensor}


@pytest.fixture
def onnx_wrapper(onnx_model, onnx_input):
    """Wrapper produced by running the model analyzer on the example model."""
    wrapper = ModelAnalyzer.analyze(onnx_model, input_data=onnx_input)
    return wrapper


def test_onnx_hook(onnx_wrapper):
    """The analyzer must pick the ONNX-specific wrapper for ModelProto objects."""
    assert isinstance(onnx_wrapper, ONNXModelWrapper)


def test_onnx_io(onnx_wrapper: ModelWrapper, tmpdir, onnx_input):
    """Round-trip the wrapper through artifact dump/load plus pyjackson
    (de)serialization and check it still predicts a single ndarray."""
    with onnx_wrapper.dump() as artifacts:
        artifacts.materialize(tmpdir)

    restored: ONNXModelWrapper = deserialize(serialize(onnx_wrapper), ModelWrapper)
    assert isinstance(restored, ONNXModelWrapper)
    restored.load(tmpdir)

    prediction = restored.run(onnx_input)
    assert isinstance(prediction, list) and len(prediction) == 1
    assert isinstance(prediction[0], np.ndarray)


def test_onnx_backend_subclass():
    """Subclass registration rules of ONNXInferenceBackend:

    - abstract subclasses are never registered
    - concrete subclasses must declare both ``name`` and ``requirements``
    - concrete, fully-declared subclasses end up in the registry
    """
    class AbstractSubclass(ONNXInferenceBackend):
        name = 'abstract'
        requirements = []

    # run() is still abstract, so no registration should have happened
    assert AbstractSubclass.name not in ONNXInferenceBackend.subtypes

    with pytest.raises(AttributeError):
        class NoName(ONNXInferenceBackend):
            requirements = []

            def run(self, data):
                pass

    with pytest.raises(AttributeError):
        class NoReqs(ONNXInferenceBackend):
            name = 'aaa'

            def run(self, data):
                pass

    class Good(ONNXInferenceBackend):
        name = 'aaa'
        requirements = []

        def run(self, data):
            pass

    try:
        assert Good.name in ONNXInferenceBackend.subtypes
    finally:
        # Clean up the *global* registry even if the assertion above fails,
        # so the dummy backend cannot leak into other tests. pop(..., None)
        # also avoids masking an assertion failure with a KeyError.
        ONNXInferenceBackend.subtypes.pop(Good.name, None)