Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 5936765

Browse files
authored
Expose params passed to mlem.api.save to fastapi's interface.json (#670)
close #664
1 parent 6ca9d0e commit 5936765

File tree

7 files changed

+57
-10
lines changed

7 files changed

+57
-10
lines changed

mlem/core/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
def get_object_metadata(
3737
obj: Any,
3838
sample_data=None,
39-
params: Dict[str, str] = None,
39+
params: Dict[str, Any] = None,
4040
preprocess: Union[Any, Dict[str, Any]] = None,
4141
postprocess: Union[Any, Dict[str, Any]] = None,
4242
) -> Union[MlemData, MlemModel]:
@@ -97,7 +97,7 @@ def save(
9797
project: Optional[str] = None,
9898
sample_data=None,
9999
fs: Optional[AbstractFileSystem] = None,
100-
params: Dict[str, str] = None,
100+
params: Dict[str, Any] = None,
101101
preprocess: Union[Any, Dict[str, Any]] = None,
102102
postprocess: Union[Any, Dict[str, Any]] = None,
103103
) -> MlemObject:

mlem/core/objects.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dataclasses
66
import hashlib
77
import itertools
8+
import json
89
import os
910
import posixpath
1011
import time
@@ -50,6 +51,7 @@
5051
MlemError,
5152
MlemObjectNotFound,
5253
MlemObjectNotSavedError,
54+
SerializationError,
5355
WrongABCType,
5456
WrongMetaSubType,
5557
WrongMetaType,
@@ -91,9 +93,19 @@ class Config:
9193
object_type: ClassVar[str]
9294
location: Optional[Location] = None
9395
"""MlemObject location [transient]"""
94-
params: Dict[str, str] = {}
96+
params: Dict[str, Any] = {}
9597
"""Arbitrary map of additional parameters"""
9698

99+
@validator("params")
100+
def params_are_serializable( # pylint: disable=no-self-argument
101+
cls, value # noqa: B902
102+
):
103+
try:
104+
json.dumps(value)
105+
except TypeError as e:
106+
raise SerializationError(f"Can't serialize object: {value}") from e
107+
return value
108+
97109
@property
98110
def loc(self) -> Location:
99111
if self.location is None:
@@ -751,7 +763,7 @@ def from_obj(
751763
model: Any,
752764
sample_data: Any = None,
753765
methods_sample_data: Dict[str, Any] = None,
754-
params: Dict[str, str] = None,
766+
params: Dict[str, Any] = None,
755767
preprocess: Union[Any, Dict[str, Any]] = None,
756768
postprocess: Union[Any, Dict[str, Any]] = None,
757769
) -> "MlemModel":
@@ -931,7 +943,7 @@ def data(self):
931943
def from_data(
932944
cls,
933945
data: Any,
934-
params: Dict[str, str] = None,
946+
params: Dict[str, Any] = None,
935947
) -> "MlemData":
936948
data_type = DataType.create(
937949
data,

mlem/runtime/interface.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class VersionedInterfaceDescriptor(BaseModel):
9999
methods: InterfaceDescriptor
100100
version: str = mlem.version.__version__
101101
"""mlem version"""
102+
meta: Any
102103

103104

104105
class Interface(ABC, MlemABC):
@@ -201,9 +202,14 @@ def get_descriptor(self) -> InterfaceDescriptor:
201202
}
202203
)
203204

205+
def get_model_meta(self):
206+
return None
207+
204208
def get_versioned_descriptor(self) -> VersionedInterfaceDescriptor:
205209
return VersionedInterfaceDescriptor(
206-
version=mlem.__version__, methods=self.get_descriptor()
210+
version=mlem.__version__,
211+
methods=self.get_descriptor(),
212+
meta=self.get_model_meta(),
207213
)
208214

209215

@@ -267,7 +273,7 @@ def _check_no_signature(data):
267273

268274

269275
class ModelInterface(Interface):
270-
"""Interface that descibes model methods"""
276+
"""Interface that describes model methods"""
271277

272278
type: ClassVar[str] = "model"
273279
model: MlemModel
@@ -352,3 +358,6 @@ def get_method_args(
352358
a.name: a.type_
353359
for a in self.model.model_type.methods[method_name].args
354360
}
361+
362+
def get_model_meta(self):
363+
return self.model.params

mlem/runtime/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,6 @@ def get_method_signature(self, method_name: str) -> InterfaceMethod:
335335
],
336336
returns=self._get_response(method_name, signature.returns),
337337
)
338+
339+
def get_model_meta(self):
340+
return getattr(getattr(self.interface, "model", None), "params", None)

tests/contrib/test_fastapi.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ def test_endpoint(f_client, f_interface: Interface, create_mlem_client, train):
145145
assert response.json() == [0] * 50 + [1] * 50 + [2] * 50
146146

147147

148+
def test_params_exposed_to_interface():
149+
model = MlemModel.from_obj(
150+
lambda x: x, sample_data="sample", params={"a": "b"}
151+
)
152+
interface = ModelInterface.from_model(model)
153+
154+
app = FastAPIServer().app_init(interface)
155+
client = TestClient(app)
156+
157+
docs = client.get("/interface.json")
158+
assert docs.status_code == 200, docs.json()
159+
assert docs.json()["meta"] == {"a": "b"}
160+
161+
148162
@pytest.mark.parametrize(
149163
"data",
150164
[

tests/core/test_metadata.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import shutil
44
import sys
55
import tempfile
6+
from datetime import datetime
67
from pathlib import Path
78
from urllib.parse import quote_plus
89

@@ -16,7 +17,7 @@
1617

1718
from mlem.api import init
1819
from mlem.contrib.heroku.meta import HerokuEnv
19-
from mlem.core.errors import InvalidArgumentError
20+
from mlem.core.errors import InvalidArgumentError, SerializationError
2021
from mlem.core.meta_io import MLEM_EXT
2122
from mlem.core.metadata import (
2223
list_objects,
@@ -42,9 +43,16 @@
4243
@pytest.mark.parametrize("obj", [lazy_fixture("model"), lazy_fixture("train")])
4344
def test_save_with_meta_fields(obj, tmpdir):
4445
path = str(tmpdir / "obj")
45-
save(obj, path, params={"a": "b"})
46+
save(obj, path, params={"a": {"b": ["c", "d", 1]}})
4647
new = load_meta(path)
47-
assert new.params == {"a": "b"}
48+
assert new.params == {"a": {"b": ["c", "d", 1]}}
49+
50+
51+
@pytest.mark.parametrize("obj", [lazy_fixture("model")])
52+
def test_save_with_meta_fields_fails(obj, tmpdir):
53+
path = str(tmpdir / "obj")
54+
with pytest.raises(SerializationError):
55+
save(obj, path, params={"a": datetime.now()})
4856

4957

5058
def test_saving_with_project(model, tmpdir):

tests/runtime/test_interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def test_interface_descriptor__to_dict(interface: Interface):
7070

7171
assert d.dict() == {
7272
"version": mlem.__version__,
73+
"meta": None,
7374
"methods": {
7475
"method1": {
7576
"args": [

0 commit comments

Comments
 (0)