From 37d9e1ec9097441d96b6acb606f9f0ec7598ba75 Mon Sep 17 00:00:00 2001 From: mike0sv Date: Wed, 28 Dec 2022 18:53:05 +0300 Subject: [PATCH 1/4] WIP transformers support --- mlem/contrib/transformers.py | 171 +++++++++++++++++++++++++++++++++++ mlem/core/metadata.py | 9 +- mlem/core/model.py | 2 +- mlem/core/objects.py | 12 +-- mlem/ext.py | 5 + setup.py | 2 + 6 files changed, 189 insertions(+), 12 deletions(-) create mode 100644 mlem/contrib/transformers.py diff --git a/mlem/contrib/transformers.py b/mlem/contrib/transformers.py new file mode 100644 index 00000000..6b076d85 --- /dev/null +++ b/mlem/contrib/transformers.py @@ -0,0 +1,171 @@ +import os +import tempfile +from enum import Enum +from importlib import import_module +from typing import Any, ClassVar, Dict, Optional, Type, Union + +from pydantic import BaseModel +from transformers import ( + AutoModel, + AutoTokenizer, + BatchEncoding, + PreTrainedTokenizer, + TensorType, +) +from transformers.modeling_utils import PreTrainedModel + +from mlem.core.artifacts import Artifacts +from mlem.core.data_type import ( + DataHook, + DataSerializer, + DataType, + DataWriter, + JsonTypes, + WithDefaultSerializer, +) +from mlem.core.hooks import IsInstanceHookMixin +from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature +from mlem.core.requirements import InstallableRequirement, Requirements + + +class ObjectType(str, Enum): + MODEL = "model" + TOKENIZER = "tokenizer" + + +_loaders = {ObjectType.MODEL: AutoModel, ObjectType.TOKENIZER: AutoTokenizer} + +_bases = { + PreTrainedModel: ObjectType.MODEL, + PreTrainedTokenizer: ObjectType.TOKENIZER, +} + + +def get_object_type(obj) -> ObjectType: + for base, obj_type in _bases.items(): + if isinstance(obj, base): + return obj_type + raise ValueError(f"Cannot determine object type for {obj}") + + +class TransformersIO(BufferModelIO): + type: ClassVar = "transformers" + + class Config: + use_enum_values = True + + obj_type: ObjectType + + def save_model(self, model: PreTrainedModel, path: str): + model.save_pretrained(path) + + @property + def load_class(self): + return _loaders[self.obj_type] + + def load(self, artifacts: Artifacts): + with tempfile.TemporaryDirectory() as tmpdir: + for name, art in artifacts.items(): + art.materialize(os.path.join(tmpdir, name)) + return self.load_class.from_pretrained(tmpdir) + + +class TokenizerModelType(ModelType, ModelHook, IsInstanceHookMixin): + type: ClassVar = "transformers" + valid_types: ClassVar = (PreTrainedModel, PreTrainedTokenizer) + + class Config: + use_enum_values = True + + return_tensors: Optional[TensorType] = None + io: TransformersIO + + @classmethod + def process( + cls, + obj: Any, + sample_data: Optional[Any] = None, + methods_sample_data: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> ModelType: + call_kwargs = {} + return_tensors = kwargs.get("return_tensors") + if return_tensors: + call_kwargs["return_tensors"] = return_tensors + sample_data = (methods_sample_data or {}).get("__call__", sample_data) + return TokenizerModelType( + methods={ + "__call__": Signature.from_method( + obj.__call__, + sample_data, + auto_infer=sample_data is not None, + **call_kwargs, + ) + }, + io=TransformersIO(obj_type=get_object_type(obj)), + ) + + def get_requirements(self) -> Requirements: + reqs = super().get_requirements() + if self.io.obj_type == ObjectType.TOKENIZER: + try: + reqs += InstallableRequirement.from_module( + import_module("sentencepiece") + ) + reqs += InstallableRequirement.from_module( + 
import_module("google.protobuf"), package_name="protobuf" + ) + except ImportError: + pass + return reqs + + +class BatchEncodingType( + WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin +): + class Config: + use_enum_values = True + + valid_types: ClassVar = BatchEncoding + return_tensors: Optional[TensorType] = None + + @staticmethod + def get_tensors_type(obj: BatchEncoding) -> TensorType: + types = {type(v) for v in obj.values()} + if len(types) > 1: + raise ValueError(f"Mixed tensor types in {obj}") + type_ = next(iter(types)) + if type_.__module__ == "torch": + return TensorType.PYTORCH + + raise ValueError(f"Unknown tensor type {type_}") + + @classmethod + def process(cls, obj: BatchEncoding, **kwargs) -> DataType: + return BatchEncodingType(return_tensors=cls.get_tensors_type(obj)) + + def get_requirements(self) -> Requirements: + return Requirements.new("transformers") + + def get_writer( + self, project: str = None, filename: str = None, **kwargs + ) -> DataWriter: + pass + + +class BatchEncodingSerializer(DataSerializer[BatchEncodingType]): + data_class: ClassVar = BatchEncodingType + is_default: ClassVar = True + + def serialize( + self, data_type: BatchEncodingType, instance: Any + ) -> JsonTypes: + pass + + def deserialize(self, data_type: BatchEncodingType, obj: JsonTypes) -> Any: + pass + + def get_model( + self, data_type: BatchEncodingType, prefix: str = "" + ) -> Union[Type[BaseModel], type]: + pass diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index 06fab034..163a4cc3 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -39,14 +39,12 @@ def get_object_metadata( params: Dict[str, str] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, + **kwargs, ) -> Union[MlemData, MlemModel]: """Convert given object to appropriate MlemObject subclass""" if preprocess is None and postprocess is None: try: - return MlemData.from_data( - obj, - params=params, - ) + return MlemData.from_data(obj, params=params, **kwargs) except HookNotFound: pass @@ -56,6 +54,7 @@ def get_object_metadata( params=params, preprocess=preprocess, postprocess=postprocess, + **kwargs, ) @@ -100,6 +99,7 @@ def save( params: Dict[str, str] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, + **kwargs, ) -> MlemObject: """Saves given object to a given path @@ -125,6 +125,7 @@ def save( params=params, preprocess=preprocess, postprocess=postprocess, + **kwargs, ) log_meta_params(meta, add_object_type=True) path = os.fspath(path) diff --git a/mlem/core/model.py b/mlem/core/model.py index ee6e3a97..6eddc268 100644 --- a/mlem/core/model.py +++ b/mlem/core/model.py @@ -131,7 +131,7 @@ def from_argspec( f"auto_infer=True, but no value for {name} argument" ) type_ = DataAnalyzer.analyze( - defaults.get(name, call_kwargs.get(name)) + call_kwargs.get(name, defaults.get(name)) ) else: type_ = UnspecifiedDataType() diff --git a/mlem/core/objects.py b/mlem/core/objects.py index 1d936efc..cfd0a17a 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -732,12 +732,13 @@ def from_obj( params: Dict[str, str] = None, preprocess: Union[Any, Dict[str, Any]] = None, postprocess: Union[Any, Dict[str, Any]] = None, + **kwargs, ) -> "MlemModel": mlem_model = MlemModel( params=params or {}, ) model_hook = ModelAnalyzer.find_hook(model) - model_type = model_hook.process(model) + model_type = model_hook.process(model, **kwargs) methods = set(model_type.methods) if ( methods_sample_data is 
not None @@ -763,6 +764,7 @@ def from_obj( model, sample_data=sample_data, methods_sample_data=_methods_sample_data, + **kwargs, ) if mt.model is None: mt = mt.bind(model) @@ -905,13 +907,9 @@ def data(self): @classmethod def from_data( - cls, - data: Any, - params: Dict[str, str] = None, + cls, data: Any, params: Dict[str, str] = None, **kwargs ) -> "MlemData": - data_type = DataType.create( - data, - ) + data_type = DataType.create(data, **kwargs) meta = MlemData( requirements=data_type.get_requirements().expanded, params=params or {}, diff --git a/mlem/ext.py b/mlem/ext.py index 9408781d..499e597a 100644 --- a/mlem/ext.py +++ b/mlem/ext.py @@ -99,6 +99,11 @@ class ExtensionLoader: Extension("mlem.contrib.onnx", ["onnx"], False), Extension("mlem.contrib.tensorflow", ["tensorflow"], False), Extension("mlem.contrib.torch", ["torch"], False), + Extension( + "mlem.contrib.transformers", + ["transformers", "sentencepiece"], + False, + ), Extension("mlem.contrib.catboost", ["catboost"], False), # Extension('mlem.contrib.aiohttp', ['aiohttp', 'aiohttp_swagger']), # Extension('mlem.contrib.flask', ['flask', 'flasgger'], False), diff --git a/setup.py b/setup.py index 777ce550..68bf3b44 100644 --- a/setup.py +++ b/setup.py @@ -230,6 +230,8 @@ "serializer.torch = mlem.contrib.torch:TorchTensorSerializer", "data_writer.torch = mlem.contrib.torch:TorchTensorWriter", "serializer.torch_image = mlem.contrib.torchvision:TorchImageSerializer", + "model_type.transformers = mlem.contrib.transformers:TokenizerModelType", + "model_io.transformers = mlem.contrib.transformers:TransformersIO", "builder.conda = mlem.contrib.venv:CondaBuilder", "requirement.conda = mlem.contrib.venv:CondaPackageRequirement", "builder.venv = mlem.contrib.venv:VenvBuilder", From 4691e28c32ad4f808a4aa3b51b1afbd6bb6bfafb Mon Sep 17 00:00:00 2001 From: mike0sv Date: Thu, 29 Dec 2022 15:53:33 +0300 Subject: [PATCH 2/4] Add batchencoding type for tokenizers --- mlem/contrib/numpy.py | 5 ++- mlem/contrib/tensorflow.py | 4 +- mlem/contrib/torch.py | 10 ++++- mlem/contrib/transformers.py | 87 ++++++++++++++++++++---------------- 4 files changed, 63 insertions(+), 43 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 5861b11f..05d6d786 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -107,7 +107,10 @@ def _abstract_shape(shape): @classmethod def process(cls, obj, **kwargs) -> DataType: return NumpyNdarrayType( - shape=cls._abstract_shape(obj.shape), dtype=obj.dtype.name + shape=cls._abstract_shape(obj.shape) + if not kwargs["is_dynamic"] + else tuple(None for _ in obj.shape), + dtype=obj.dtype.name, ) @classmethod diff --git a/mlem/contrib/tensorflow.py b/mlem/contrib/tensorflow.py index 6458eb98..0e0fed49 100644 --- a/mlem/contrib/tensorflow.py +++ b/mlem/contrib/tensorflow.py @@ -85,7 +85,9 @@ def subtype(self, subshape: Tuple[Optional[int], ...]): @classmethod def process(cls, obj: tf.Tensor, **kwargs) -> DataType: return TFTensorDataType( - shape=(None,) + tuple(obj.shape)[1:], + shape=(None,) + tuple(obj.shape)[1:] + if not kwargs["is_dynamic"] + else tuple(None for _ in obj.shape), dtype=obj.dtype.name, ) diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py index 4bd0c196..399e746c 100644 --- a/mlem/contrib/torch.py +++ b/mlem/contrib/torch.py @@ -66,7 +66,11 @@ class TorchTensorDataType( """Type name of `torch.Tensor` elements""" def check_shape(self, tensor, exc_type): - if tuple(tensor.shape)[1:] != self.shape[1:]: + shape = tuple( + s if s is not None else tensor.shape[i] + for i, s 
in enumerate(self.shape) + ) + if tuple(tensor.shape) != shape: raise exc_type( f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}" ) @@ -91,7 +95,9 @@ def subtype(self, subshape: Tuple[Optional[int], ...]): @classmethod def process(cls, obj: torch.Tensor, **kwargs) -> DataType: return TorchTensorDataType( - shape=(None,) + obj.shape[1:], + shape=(None,) + obj.shape[1:] + if not kwargs["is_dynamic"] + else tuple(None for _ in obj.shape), dtype=str(obj.dtype)[len("torch") + 1 :], ) diff --git a/mlem/contrib/transformers.py b/mlem/contrib/transformers.py index 6b076d85..f5d66320 100644 --- a/mlem/contrib/transformers.py +++ b/mlem/contrib/transformers.py @@ -2,9 +2,8 @@ import tempfile from enum import Enum from importlib import import_module -from typing import Any, ClassVar, Dict, Optional, Type, Union +from typing import Any, ClassVar, Dict, Optional -from pydantic import BaseModel from transformers import ( AutoModel, AutoTokenizer, @@ -16,12 +15,11 @@ from mlem.core.artifacts import Artifacts from mlem.core.data_type import ( + DataAnalyzer, DataHook, - DataSerializer, DataType, - DataWriter, - JsonTypes, - WithDefaultSerializer, + DictSerializer, + DictType, ) from mlem.core.hooks import IsInstanceHookMixin from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature @@ -93,15 +91,17 @@ def process( if return_tensors: call_kwargs["return_tensors"] = return_tensors sample_data = (methods_sample_data or {}).get("__call__", sample_data) + signature = Signature.from_method( + obj.__call__, + sample_data, + auto_infer=sample_data is not None, + **call_kwargs, + ) + [a for a in signature.args if a.name == "return_tensors"][ + 0 + ].default = return_tensors return TokenizerModelType( - methods={ - "__call__": Signature.from_method( - obj.__call__, - sample_data, - auto_infer=sample_data is not None, - **call_kwargs, - ) - }, + methods={"__call__": signature}, io=TransformersIO(obj_type=get_object_type(obj)), ) @@ -120,9 +120,14 @@ def get_requirements(self) -> Requirements: return reqs -class BatchEncodingType( - WithDefaultSerializer, DataType, DataHook, IsInstanceHookMixin -): +_ADDITIONAL_DEPS = { + TensorType.NUMPY: "numpy", + TensorType.PYTORCH: "torch", + TensorType.TENSORFLOW: "tensorflow", +} + + +class BatchEncodingType(DictType, DataHook, IsInstanceHookMixin): class Config: use_enum_values = True @@ -130,42 +135,46 @@ class Config: return_tensors: Optional[TensorType] = None @staticmethod - def get_tensors_type(obj: BatchEncoding) -> TensorType: + def get_tensors_type(obj: BatchEncoding) -> Optional[TensorType]: types = {type(v) for v in obj.values()} if len(types) > 1: raise ValueError(f"Mixed tensor types in {obj}") type_ = next(iter(types)) if type_.__module__ == "torch": return TensorType.PYTORCH - + if type_.__module__.startswith("tensorflow"): + return TensorType.TENSORFLOW + if type_.__module__.startswith("numpy"): + return TensorType.NUMPY + if type_ is list: + return None raise ValueError(f"Unknown tensor type {type_}") @classmethod def process(cls, obj: BatchEncoding, **kwargs) -> DataType: - return BatchEncodingType(return_tensors=cls.get_tensors_type(obj)) + return BatchEncodingType( + return_tensors=cls.get_tensors_type(obj), + item_types={ + k: DataAnalyzer.analyze(v, is_dynamic=True, **kwargs) + for (k, v) in obj.items() + }, + ) def get_requirements(self) -> Requirements: - return Requirements.new("transformers") - - def get_writer( - self, project: str = None, filename: str = None, **kwargs - ) -> DataWriter: - pass 
+ new = Requirements.new("transformers") + if self.return_tensors in _ADDITIONAL_DEPS: + new += Requirements.new(_ADDITIONAL_DEPS[self.return_tensors]) + return new -class BatchEncodingSerializer(DataSerializer[BatchEncodingType]): +class BatchEncodingSerializer(DictSerializer): data_class: ClassVar = BatchEncodingType is_default: ClassVar = True - def serialize( - self, data_type: BatchEncodingType, instance: Any - ) -> JsonTypes: - pass - - def deserialize(self, data_type: BatchEncodingType, obj: JsonTypes) -> Any: - pass - - def get_model( - self, data_type: BatchEncodingType, prefix: str = "" - ) -> Union[Type[BaseModel], type]: - pass + @staticmethod + def _check_type_and_keys(data_type, obj, exc_type): + data_type.check_type(obj, BatchEncoding, exc_type) + if set(obj.keys()) != set(data_type.item_types.keys()): + raise exc_type( + f"given dict has keys: {set(obj.keys())}, expected: {set(data_type.item_types.keys())}" + ) From ef3c5d778bfce47fc53de61b3154b288d61ba74f Mon Sep 17 00:00:00 2001 From: mike0sv Date: Fri, 30 Dec 2022 19:40:51 +0300 Subject: [PATCH 3/4] Add tests for batchencoding type --- mlem/contrib/numpy.py | 8 ++ mlem/contrib/tensorflow.py | 9 +- mlem/contrib/torch.py | 13 ++- mlem/contrib/transformers.py | 64 ++++++++++- mlem/core/data_type.py | 2 +- tests/contrib/test_transformers.py | 175 +++++++++++++++++++++++++++++ 6 files changed, 256 insertions(+), 15 deletions(-) create mode 100644 tests/contrib/test_transformers.py diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 05d6d786..4e7b67af 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -261,3 +261,11 @@ def read_batch( self, artifacts: Artifacts, batch_size: int ) -> Iterator[DataType]: raise NotImplementedError + + +def apply_shape_pattern( + abs_shape: Tuple[Optional[int], ...], shape: Tuple[int, ...] 
+): + return tuple( + s if s is not None else shape[i] for i, s in enumerate(abs_shape) + ) diff --git a/mlem/contrib/tensorflow.py b/mlem/contrib/tensorflow.py index 0e0fed49..2d8cb305 100644 --- a/mlem/contrib/tensorflow.py +++ b/mlem/contrib/tensorflow.py @@ -14,7 +14,10 @@ from pydantic import conlist, create_model from tensorflow.python.keras.saving.saved_model_experimental import sequential -from mlem.contrib.numpy import python_type_from_np_string_repr +from mlem.contrib.numpy import ( + apply_shape_pattern, + python_type_from_np_string_repr, +) from mlem.core.artifacts import Artifacts, Storage from mlem.core.data_type import ( DataHook, @@ -60,7 +63,9 @@ def tf_type(self): return getattr(tf, self.dtype) def check_shape(self, tensor, exc_type): - if tuple(tensor.shape)[1:] != self.shape[1:]: + if tuple(tensor.shape) != apply_shape_pattern( + self.shape, tensor.shape + ): raise exc_type( f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}" ) diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py index 399e746c..d59fcd2d 100644 --- a/mlem/contrib/torch.py +++ b/mlem/contrib/torch.py @@ -13,7 +13,10 @@ from pydantic import conlist, create_model from mlem.config import MlemConfigBase -from mlem.contrib.numpy import python_type_from_np_string_repr +from mlem.contrib.numpy import ( + apply_shape_pattern, + python_type_from_np_string_repr, +) from mlem.core.artifacts import Artifacts, FSSpecArtifact, Storage from mlem.core.data_type import ( DataHook, @@ -66,11 +69,9 @@ class TorchTensorDataType( """Type name of `torch.Tensor` elements""" def check_shape(self, tensor, exc_type): - shape = tuple( - s if s is not None else tensor.shape[i] - for i, s in enumerate(self.shape) - ) - if tuple(tensor.shape) != shape: + if tuple(tensor.shape) != apply_shape_pattern( + self.shape, tensor.shape + ): raise exc_type( f"given tensor is of shape: {(None,) + tuple(tensor.shape)[1:]}, expected: {self.shape}" ) diff --git a/mlem/contrib/transformers.py b/mlem/contrib/transformers.py index f5d66320..cf1943f2 100644 --- a/mlem/contrib/transformers.py +++ b/mlem/contrib/transformers.py @@ -2,7 +2,7 @@ import tempfile from enum import Enum from importlib import import_module -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, Iterator, Optional, Tuple from transformers import ( AutoModel, @@ -13,13 +13,16 @@ ) from transformers.modeling_utils import PreTrainedModel -from mlem.core.artifacts import Artifacts +from mlem.core.artifacts import Artifacts, Storage from mlem.core.data_type import ( DataAnalyzer, DataHook, DataType, + DataWriter, + DictReader, DictSerializer, DictType, + DictWriter, ) from mlem.core.hooks import IsInstanceHookMixin from mlem.core.model import BufferModelIO, ModelHook, ModelType, Signature @@ -120,7 +123,7 @@ def get_requirements(self) -> Requirements: return reqs -_ADDITIONAL_DEPS = { +ADDITIONAL_DEPS = { TensorType.NUMPY: "numpy", TensorType.PYTORCH: "torch", TensorType.TENSORFLOW: "tensorflow", @@ -131,6 +134,7 @@ class BatchEncodingType(DictType, DataHook, IsInstanceHookMixin): class Config: use_enum_values = True + type: ClassVar = "batch_encoding" valid_types: ClassVar = BatchEncoding return_tensors: Optional[TensorType] = None @@ -150,6 +154,14 @@ def get_tensors_type(obj: BatchEncoding) -> Optional[TensorType]: return None raise ValueError(f"Unknown tensor type {type_}") + @property + def return_tensors_enum(self) -> Optional[TensorType]: + if self.return_tensors is not None and not isinstance( + 
self.return_tensors, TensorType + ): + return TensorType(self.return_tensors) + return self.return_tensors + @classmethod def process(cls, obj: BatchEncoding, **kwargs) -> DataType: return BatchEncodingType( @@ -162,10 +174,15 @@ def process(cls, obj: BatchEncoding, **kwargs) -> DataType: def get_requirements(self) -> Requirements: new = Requirements.new("transformers") - if self.return_tensors in _ADDITIONAL_DEPS: - new += Requirements.new(_ADDITIONAL_DEPS[self.return_tensors]) + if self.return_tensors_enum in ADDITIONAL_DEPS: + new += Requirements.new(ADDITIONAL_DEPS[self.return_tensors_enum]) return new + def get_writer( + self, project: str = None, filename: str = None, **kwargs + ) -> DataWriter: + return BatchEncodingWriter(**kwargs) + class BatchEncodingSerializer(DictSerializer): data_class: ClassVar = BatchEncodingType @@ -173,8 +190,43 @@ class BatchEncodingSerializer(DictSerializer): @staticmethod def _check_type_and_keys(data_type, obj, exc_type): - data_type.check_type(obj, BatchEncoding, exc_type) + data_type.check_type(obj, (dict, BatchEncoding), exc_type) if set(obj.keys()) != set(data_type.item_types.keys()): raise exc_type( f"given dict has keys: {set(obj.keys())}, expected: {set(data_type.item_types.keys())}" ) + + def deserialize(self, data_type: DictType, obj): + assert isinstance(data_type, BatchEncodingType) + return BatchEncoding( + super().deserialize(data_type, obj), + tensor_type=data_type.return_tensors_enum, + ) + + +class BatchEncodingReader(DictReader): + type: ClassVar = "batch_encoding" + + def read(self, artifacts: Artifacts) -> DictType: + res = super().read(artifacts) + return res.bind(BatchEncoding(res.data)) + + def read_batch( + self, artifacts: Artifacts, batch_size: int + ) -> Iterator[DictType]: + raise NotImplementedError + + +class BatchEncodingWriter(DictWriter): + type: ClassVar = "batch_encoding" + + def write( + self, data: DataType, storage: Storage, path: str + ) -> Tuple[DictReader, Artifacts]: + res, art = super().write(data, storage, path) + return ( + BatchEncodingReader( + data_type=res.data_type, item_readers=res.item_readers + ), + art, + ) diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 6a16c97f..1013bfbb 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -816,7 +816,7 @@ class DictWriter(DataWriter): def write( self, data: DataType, storage: Storage, path: str - ) -> Tuple[DataReader, Artifacts]: + ) -> Tuple["DictReader", Artifacts]: if not isinstance(data, DictType): raise ValueError( f"expected data to be of DictType, got {type(data)} instead" diff --git a/tests/contrib/test_transformers.py b/tests/contrib/test_transformers.py new file mode 100644 index 00000000..46280257 --- /dev/null +++ b/tests/contrib/test_transformers.py @@ -0,0 +1,175 @@ +from functools import partial + +import numpy as np +import pytest +import tensorflow as tf +import torch +from pydantic import parse_obj_as +from transformers import ( + AlbertModel, + AlbertTokenizer, + BatchEncoding, + DistilBertModel, + DistilBertTokenizer, + TensorType, +) + +from mlem.contrib.transformers import ADDITIONAL_DEPS, BatchEncodingType +from mlem.core.data_type import DataAnalyzer, DataType +from tests.conftest import data_write_read_check + +FULL_TESTS = True + +TOKENIZERS = { + AlbertTokenizer: "albert-base-v2", + DistilBertTokenizer: "distilbert-base-uncased", +} + +MODELS = { + AlbertModel: "albert-base-v2", + DistilBertModel: "distilbert-base-uncased", +} + +ONE_MODEL = AlbertModel +ONE_TOKENIZER = AlbertTokenizer + +for_model = 
pytest.mark.parametrize( + "model", + [ONE_MODEL.from_pretrained(MODELS[ONE_MODEL])] + if not FULL_TESTS + else [m.from_pretrained(v) for m, v in MODELS.items()], +) + +for_tokenizer = pytest.mark.parametrize( + "tokenizer", + [ONE_TOKENIZER.from_pretrained(TOKENIZERS[ONE_TOKENIZER])] + if not FULL_TESTS + else [m.from_pretrained(v) for m, v in TOKENIZERS.items()], +) + + +def test_analyzing_model(): + pass + + +def test_analyzing_tokenizer(): + pass + + +def test_serving_model(): + pass + + +def test_serving_tokenizer(): + pass + + +def test_model_reqs(): + pass + + +def test_tokenizer_reqs(): + pass + + +# pylint: disable=protected-access +@for_tokenizer +@pytest.mark.parametrize( + "return_tensors,typename,eq", + [ + ("pt", "TorchTensor", lambda a, b: torch.all(a.eq(b))), + ("tf", "TFTensor", lambda a, b: tf.equal(a, b)._numpy().all()), + ("np", "NumpyNdarray", lambda a, b: np.equal(a, b).all()), + (None, "Array", None), + ], +) +def test_batch_encoding(tokenizer, return_tensors, typename, eq): + data = tokenizer("aaa bbb", return_tensors=return_tensors) + + data_type = DataAnalyzer.analyze(data) + assert isinstance(data_type, BatchEncodingType) + expected_reqs = ["transformers"] + if return_tensors is not None: + expected_reqs += [ADDITIONAL_DEPS[TensorType(return_tensors)]] + assert data_type.get_requirements().modules == expected_reqs + + item_type = DataAnalyzer.analyze(data["input_ids"], is_dynamic=True).dict() + expected_payload = { + "item_types": { + "attention_mask": item_type, + "input_ids": item_type, + "token_type_ids": item_type, + }, + "type": "batch_encoding", + } + if return_tensors is not None: + expected_payload["return_tensors"] = return_tensors + if "token_type_ids" not in data: + del expected_payload["item_types"]["token_type_ids"] + assert data_type.dict() == expected_payload + data_type2 = parse_obj_as(DataType, data_type.dict()) + assert data_type2 == data_type + + assert data_type.get_model().__name__ == data_type2.get_model().__name__ + schema_item_type = {"items": {"type": "integer"}, "type": "array"} + if return_tensors is None: + schema_item_type = {"type": "integer"} + expected_schema = { + "definitions": { + f"attention_mask_{typename}": { + "items": schema_item_type, + "title": f"attention_mask_{typename}", + "type": "array", + }, + f"input_ids_{typename}": { + "items": schema_item_type, + "title": f"input_ids_{typename}", + "type": "array", + }, + f"token_type_ids_{typename}": { + "items": schema_item_type, + "title": f"token_type_ids_{typename}", + "type": "array", + }, + }, + "properties": { + "attention_mask": { + "$ref": f"#/definitions/attention_mask_{typename}" + }, + "input_ids": {"$ref": f"#/definitions/input_ids_{typename}"}, + "token_type_ids": { + "$ref": f"#/definitions/token_type_ids_{typename}" + }, + }, + "required": ["input_ids", "token_type_ids", "attention_mask"], + "title": "DictType", + "type": "object", + } + if "token_type_ids" not in data: + del expected_schema["definitions"][f"token_type_ids_{typename}"] + del expected_schema["properties"]["token_type_ids"] + expected_schema["required"].remove("token_type_ids") + assert data_type.get_model().schema() == expected_schema + n_payload = data_type.get_serializer().serialize(data) + deser = data_type.get_serializer().deserialize(n_payload) + assert _batch_encoding_equals(data, deser, eq) + parse_obj_as(data_type.get_model(), n_payload) + + data_type = data_type.bind(data) + data_write_read_check( + data_type, custom_eq=partial(_batch_encoding_equals, equals=eq) + ) + + +def 
_batch_encoding_equals(first, second, equals):
+    assert isinstance(first, BatchEncoding)
+    assert isinstance(second, BatchEncoding)
+
+    assert first.keys() == second.keys()
+
+    for key in first:
+        if equals is not None:
+            assert equals(first[key], second[key])
+        else:
+            assert first[key] == second[key]
+    return True

From 40bf1fd4213b0f3695fb8bba45ae798103deab05 Mon Sep 17 00:00:00 2001
From: mike0sv
Date: Fri, 30 Dec 2022 19:48:29 +0300
Subject: [PATCH 4/4] Fix KeyError when is_dynamic is not passed to process()

---
 mlem/contrib/numpy.py      | 2 +-
 mlem/contrib/tensorflow.py | 2 +-
 mlem/contrib/torch.py      | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py
index 4e7b67af..9d67c512 100644
--- a/mlem/contrib/numpy.py
+++ b/mlem/contrib/numpy.py
@@ -108,7 +108,7 @@ def _abstract_shape(shape):
     def process(cls, obj, **kwargs) -> DataType:
         return NumpyNdarrayType(
             shape=cls._abstract_shape(obj.shape)
-            if not kwargs["is_dynamic"]
+            if not kwargs.get("is_dynamic")
             else tuple(None for _ in obj.shape),
             dtype=obj.dtype.name,
         )
diff --git a/mlem/contrib/tensorflow.py b/mlem/contrib/tensorflow.py
index 2d8cb305..00e1f0dc 100644
--- a/mlem/contrib/tensorflow.py
+++ b/mlem/contrib/tensorflow.py
@@ -91,7 +91,7 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
     def process(cls, obj: tf.Tensor, **kwargs) -> DataType:
         return TFTensorDataType(
             shape=(None,) + tuple(obj.shape)[1:]
-            if not kwargs["is_dynamic"]
+            if not kwargs.get("is_dynamic")
             else tuple(None for _ in obj.shape),
             dtype=obj.dtype.name,
         )
diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py
index d59fcd2d..d6dceb12 100644
--- a/mlem/contrib/torch.py
+++ b/mlem/contrib/torch.py
@@ -97,7 +97,7 @@ def subtype(self, subshape: Tuple[Optional[int], ...]):
     def process(cls, obj: torch.Tensor, **kwargs) -> DataType:
         return TorchTensorDataType(
             shape=(None,) + obj.shape[1:]
-            if not kwargs["is_dynamic"]
+            if not kwargs.get("is_dynamic")
             else tuple(None for _ in obj.shape),
             dtype=str(obj.dtype)[len("torch") + 1 :],
         )
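
A usage sketch for the series as a whole (not part of any patch): how the
transformers support from PATCH 1/4 is meant to be driven through
mlem.api.save/load, which these patches teach to forward extra kwargs down to
the model hook. The save paths below are illustrative, and "albert-base-v2" is
the checkpoint the tests use; treat this as a sketch of the intended flow, not
a tested example.

    from transformers import AlbertModel, AlbertTokenizer

    from mlem.api import load, save

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertModel.from_pretrained("albert-base-v2")

    # return_tensors travels through save(**kwargs) -> get_object_metadata ->
    # MlemModel.from_obj -> TokenizerModelType.process, which records it as
    # the default of the inferred __call__ signature.
    save(tokenizer, "my_tokenizer", sample_data="some text", return_tensors="pt")
    save(model, "my_model")

    # TransformersIO wrote the artifacts with save_pretrained; loading picks
    # AutoTokenizer or AutoModel from the stored obj_type and calls
    # from_pretrained on the materialized directory.
    tokenizer2 = load("my_tokenizer")
    encoded = tokenizer2("some text", return_tensors="pt")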
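
And a sketch of the BatchEncoding data type from PATCH 2/4, mirroring the
round-trip exercised in tests/contrib/test_transformers.py (the assertions
below restate what those tests check):

    import torch
    from transformers import AlbertTokenizer

    from mlem.core.data_type import DataAnalyzer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    data = tokenizer("aaa bbb", return_tensors="pt")

    # Analysis produces a BatchEncodingType; item types are analyzed with
    # is_dynamic=True, so every tensor dimension is stored as None and inputs
    # of any sequence length re-validate against the same type.
    data_type = DataAnalyzer.analyze(data)
    assert data_type.get_requirements().modules == ["transformers", "torch"]

    # Serialization reuses the DictSerializer machinery; deserialize rebuilds
    # a BatchEncoding and re-applies the recorded tensor type ("pt").
    payload = data_type.get_serializer().serialize(data)
    restored = data_type.get_serializer().deserialize(payload)
    assert torch.all(restored["input_ids"].eq(data["input_ids"]))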