diff --git a/mlflow/__init__.py b/mlflow/__init__.py
index 404fe895eb2a1..5382224912a24 100644
--- a/mlflow/__init__.py
+++ b/mlflow/__init__.py
@@ -126,6 +126,7 @@
active_run,
autolog,
create_experiment,
+ create_logged_model,
delete_experiment,
delete_run,
delete_tag,
@@ -136,6 +137,7 @@
get_artifact_uri,
get_experiment,
get_experiment_by_name,
+ get_logged_model,
get_parent_run,
get_run,
last_active_run,
@@ -153,6 +155,7 @@
log_table,
log_text,
search_experiments,
+ search_logged_models,
search_runs,
set_experiment,
set_experiment_tag,
@@ -173,6 +176,7 @@
"active_run",
"autolog",
"create_experiment",
+ "create_logged_model",
"delete_experiment",
"delete_run",
"delete_tag",
@@ -188,6 +192,7 @@
"get_experiment",
"get_experiment_by_name",
"get_last_active_trace",
+ "get_logged_model",
"get_parent_run",
"get_registry_uri",
"get_run",
@@ -212,6 +217,7 @@
"register_model",
"run",
"search_experiments",
+ "search_logged_models",
"search_model_versions",
"search_registered_models",
"search_runs",
diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py
index 483d39835e95b..4c02c95f43fd8 100644
--- a/mlflow/entities/__init__.py
+++ b/mlflow/entities/__init__.py
@@ -11,12 +11,19 @@
from mlflow.entities.file_info import FileInfo
from mlflow.entities.input_tag import InputTag
from mlflow.entities.lifecycle_stage import LifecycleStage
+from mlflow.entities.logged_model import LoggedModel
from mlflow.entities.metric import Metric
+from mlflow.entities.model_input import ModelInput
+from mlflow.entities.model_output import ModelOutput
+from mlflow.entities.model_param import ModelParam
+from mlflow.entities.model_status import ModelStatus
+from mlflow.entities.model_tag import ModelTag
from mlflow.entities.param import Param
from mlflow.entities.run import Run
from mlflow.entities.run_data import RunData
from mlflow.entities.run_info import RunInfo
from mlflow.entities.run_inputs import RunInputs
+from mlflow.entities.run_outputs import RunOutputs
from mlflow.entities.run_status import RunStatus
from mlflow.entities.run_tag import RunTag
from mlflow.entities.source_type import SourceType
@@ -46,6 +53,7 @@
"InputTag",
"DatasetInput",
"RunInputs",
+ "RunOutputs",
"Span",
"LiveSpan",
"NoOpSpan",
@@ -57,4 +65,10 @@
"TraceInfo",
"SpanStatusCode",
"_DatasetSummary",
+ "LoggedModel",
+ "ModelInput",
+ "ModelOutput",
+ "ModelStatus",
+ "ModelTag",
+ "ModelParam",
]
diff --git a/mlflow/entities/logged_model.py b/mlflow/entities/logged_model.py
new file mode 100644
index 0000000000000..b4fb1baa13757
--- /dev/null
+++ b/mlflow/entities/logged_model.py
@@ -0,0 +1,162 @@
+from typing import Any, Dict, List, Optional, Union
+
+from mlflow.entities._mlflow_object import _MlflowObject
+from mlflow.entities.metric import Metric
+from mlflow.entities.model_param import ModelParam
+from mlflow.entities.model_status import ModelStatus
+from mlflow.entities.model_tag import ModelTag
+
+
+class LoggedModel(_MlflowObject):
+ """
+ MLflow entity representing a Model logged to an MLflow Experiment.
+ """
+
+ def __init__(
+ self,
+ experiment_id: str,
+ model_id: str,
+ name: str,
+ artifact_location: str,
+ creation_timestamp: int,
+ last_updated_timestamp: int,
+ model_type: Optional[str] = None,
+ run_id: Optional[str] = None,
+ status: ModelStatus = ModelStatus.READY,
+ status_message: Optional[str] = None,
+ tags: Optional[Union[List[ModelTag], Dict[str, str]]] = None,
+ params: Optional[Union[List[ModelParam], Dict[str, str]]] = None,
+ metrics: Optional[List[Metric]] = None,
+ ):
+ super().__init__()
+ self._experiment_id: str = experiment_id
+ self._model_id: str = model_id
+ self._name: str = name
+ self._artifact_location: str = artifact_location
+ self._creation_time: int = creation_timestamp
+ self._last_updated_timestamp: int = last_updated_timestamp
+ self._model_type: Optional[str] = model_type
+ self._run_id: Optional[str] = run_id
+ self._status: ModelStatus = status
+ self._status_message: Optional[str] = status_message
+ self._tags: Dict[str, str] = (
+ {tag.key: tag.value for tag in (tags or [])} if isinstance(tags, list) else (tags or {})
+ )
+ self._params: Dict[str, str] = (
+ {param.key: param.value for param in (params or [])}
+ if isinstance(params, list)
+ else (params or {})
+ )
+ self._metrics: Optional[List[Metric]] = metrics
+
+ @property
+ def experiment_id(self) -> str:
+ """String. Experiment ID associated with this Model."""
+ return self._experiment_id
+
+ @experiment_id.setter
+ def experiment_id(self, new_experiment_id: str):
+ self._experiment_id = new_experiment_id
+
+ @property
+ def model_id(self) -> str:
+ """String. Unique ID for this Model."""
+ return self._model_id
+
+ @model_id.setter
+ def model_id(self, new_model_id: str):
+ self._model_id = new_model_id
+
+ @property
+ def name(self) -> str:
+ """String. Name for this Model."""
+ return self._name
+
+ @name.setter
+ def name(self, new_name: str):
+ self._name = new_name
+
+ @property
+ def artifact_location(self) -> str:
+ """String. Location of the model artifacts."""
+ return self._artifact_location
+
+ @artifact_location.setter
+ def artifact_location(self, new_artifact_location: str):
+ self._artifact_location = new_artifact_location
+
+ @property
+ def creation_timestamp(self) -> int:
+ """Integer. Model creation timestamp (milliseconds since the Unix epoch)."""
+ return self._creation_time
+
+ @property
+ def last_updated_timestamp(self) -> int:
+ """Integer. Timestamp of last update for this Model (milliseconds since the Unix
+ epoch).
+ """
+ return self._last_updated_timestamp
+
+ @last_updated_timestamp.setter
+ def last_updated_timestamp(self, updated_timestamp: int):
+ self._last_updated_timestamp = updated_timestamp
+
+ @property
+ def model_type(self) -> Optional[str]:
+ """String. Type of the model."""
+ return self._model_type
+
+ @model_type.setter
+ def model_type(self, new_model_type: Optional[str]):
+ self._model_type = new_model_type
+
+ @property
+ def run_id(self) -> Optional[str]:
+ """String. MLflow run ID that generated this model."""
+ return self._run_id
+
+ @property
+ def status(self) -> ModelStatus:
+ """String. Current status of this Model."""
+ return self._status
+
+ @status.setter
+ def status(self, updated_status: ModelStatus):
+ self._status = updated_status
+
+ @property
+ def status_message(self) -> Optional[str]:
+ """String. Descriptive message for error status conditions."""
+ return self._status_message
+
+ @property
+ def tags(self) -> Dict[str, str]:
+ """Dictionary of tag key (string) -> tag value for this Model."""
+ return self._tags
+
+ @property
+ def params(self) -> Dict[str, str]:
+ """Model parameters."""
+ return self._params
+
+ @property
+ def metrics(self) -> Optional[List[Metric]]:
+ """List of metrics associated with this Model."""
+ return self._metrics
+
+ @metrics.setter
+ def metrics(self, new_metrics: Optional[List[Metric]]):
+ self._metrics = new_metrics
+
+ @classmethod
+ def _properties(cls) -> List[str]:
+ # aggregate with base class properties since cls.__dict__ does not do it automatically
+ return sorted(cls._get_properties_helper())
+
+ def _add_tag(self, tag):
+ self._tags[tag.key] = tag.value
+
+ def to_dictionary(self) -> Dict[str, Any]:
+ model_dict = dict(self)
+ model_dict["status"] = str(self.status)
+ return model_dict
diff --git a/mlflow/entities/metric.py b/mlflow/entities/metric.py
index bea6926b95d1f..68bd68451a150 100644
--- a/mlflow/entities/metric.py
+++ b/mlflow/entities/metric.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
@@ -10,11 +12,31 @@ class Metric(_MlflowObject):
Metric object.
"""
- def __init__(self, key, value, timestamp, step):
+ def __init__(
+ self,
+ key,
+ value,
+ timestamp,
+ step,
+ model_id: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ dataset_digest: Optional[str] = None,
+ run_id: Optional[str] = None,
+ ):
+ if (dataset_name, dataset_digest).count(None) == 1:
+ raise MlflowException(
+ "Both dataset_name and dataset_digest must be provided if one is provided",
+ INVALID_PARAMETER_VALUE,
+ )
+
self._key = key
self._value = value
self._timestamp = timestamp
self._step = step
+ self._model_id = model_id
+ self._dataset_name = dataset_name
+ self._dataset_digest = dataset_digest
+ self._run_id = run_id
@property
def key(self):
@@ -36,16 +58,38 @@ def step(self):
"""Integer metric step (x-coordinate)."""
return self._step
+ @property
+ def model_id(self) -> Optional[str]:
+ """ID of the Model associated with the metric."""
+ return self._model_id
+
+ @property
+ def dataset_name(self) -> Optional[str]:
+ """String. Name of the dataset associated with the metric."""
+ return self._dataset_name
+
+ @property
+ def dataset_digest(self) -> Optional[str]:
+ """String. Digest of the dataset associated with the metric."""
+ return self._dataset_digest
+
+ @property
+ def run_id(self) -> Optional[str]:
+ """String. Run ID associated with the metric."""
+ return self._run_id
+
def to_proto(self):
metric = ProtoMetric()
metric.key = self.key
metric.value = self.value
metric.timestamp = self.timestamp
metric.step = self.step
+ # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto
return metric
@classmethod
def from_proto(cls, proto):
+ # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto
return cls(proto.key, proto.value, proto.timestamp, proto.step)
def __eq__(self, __o):
@@ -69,6 +113,10 @@ def to_dictionary(self):
"value": self.value,
"timestamp": self.timestamp,
"step": self.step,
+ "model_id": self.model_id,
+ "dataset_name": self.dataset_name,
+ "dataset_digest": self.dataset_digest,
+ "run_id": self.run_id,
}
@classmethod
diff --git a/mlflow/entities/model_input.py b/mlflow/entities/model_input.py
new file mode 100644
index 0000000000000..456c6db70ae5f
--- /dev/null
+++ b/mlflow/entities/model_input.py
@@ -0,0 +1,18 @@
+from mlflow.entities._mlflow_object import _MlflowObject
+
+
+class ModelInput(_MlflowObject):
+ """ModelInput object associated with a Run."""
+
+ def __init__(self, model_id: str):
+ self._model_id = model_id
+
+ def __eq__(self, other: _MlflowObject) -> bool:
+ if type(other) is type(self):
+ return self.__dict__ == other.__dict__
+ return False
+
+ @property
+ def model_id(self) -> str:
+ """Model ID."""
+ return self._model_id
diff --git a/mlflow/entities/model_output.py b/mlflow/entities/model_output.py
new file mode 100644
index 0000000000000..058a50b316271
--- /dev/null
+++ b/mlflow/entities/model_output.py
@@ -0,0 +1,24 @@
+from mlflow.entities._mlflow_object import _MlflowObject
+
+
+class ModelOutput(_MlflowObject):
+ """ModelOutput object associated with a Run."""
+
+ def __init__(self, model_id: str, step: int) -> None:
+ self._model_id = model_id
+ self._step = step
+
+ def __eq__(self, other: _MlflowObject) -> bool:
+ if type(other) is type(self):
+ return self.__dict__ == other.__dict__
+ return False
+
+ @property
+ def model_id(self) -> str:
+ """Model ID"""
+ return self._model_id
+
+ @property
+ def step(self) -> int:
+ """Step at which the model was logged"""
+ return self._step
diff --git a/mlflow/entities/model_param.py b/mlflow/entities/model_param.py
new file mode 100644
index 0000000000000..b5e4cf7fe8c65
--- /dev/null
+++ b/mlflow/entities/model_param.py
@@ -0,0 +1,38 @@
+import sys
+
+from mlflow.entities._mlflow_object import _MlflowObject
+
+
+class ModelParam(_MlflowObject):
+ """
+ MLflow entity representing a parameter of a Model.
+ """
+
+ def __init__(self, key, value):
+ if "pyspark.ml" in sys.modules:
+ import pyspark.ml.param
+
+ if isinstance(key, pyspark.ml.param.Param):
+ key = key.name
+ value = str(value)
+ self._key = key
+ self._value = value
+
+ @property
+ def key(self):
+ """String key corresponding to the parameter name."""
+ return self._key
+
+ @property
+ def value(self):
+ """String value of the parameter."""
+ return self._value
+
+ def __eq__(self, __o):
+ if isinstance(__o, self.__class__):
+ return self._key == __o._key
+
+ return False
+
+ def __hash__(self):
+ return hash(self._key)
diff --git a/mlflow/entities/model_registry/model_version.py b/mlflow/entities/model_registry/model_version.py
index a30d022c9eb0f..e47f9c8deb1de 100644
--- a/mlflow/entities/model_registry/model_version.py
+++ b/mlflow/entities/model_registry/model_version.py
@@ -1,3 +1,7 @@
+from typing import Dict, List, Optional
+
+from mlflow.entities.metric import Metric
+from mlflow.entities.model_param import ModelParam
from mlflow.entities.model_registry._model_registry_entity import _ModelRegistryEntity
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.entities.model_registry.model_version_tag import ModelVersionTag
@@ -12,140 +16,163 @@ class ModelVersion(_ModelRegistryEntity):
def __init__(
self,
- name,
- version,
- creation_timestamp,
- last_updated_timestamp=None,
- description=None,
- user_id=None,
- current_stage=None,
- source=None,
- run_id=None,
- status=ModelVersionStatus.to_string(ModelVersionStatus.READY),
- status_message=None,
- tags=None,
- run_link=None,
- aliases=None,
+ name: str,
+ version: str,
+ creation_timestamp: int,
+ last_updated_timestamp: Optional[int] = None,
+ description: Optional[str] = None,
+ user_id: Optional[str] = None,
+ current_stage: Optional[str] = None,
+ source: Optional[str] = None,
+ run_id: Optional[str] = None,
+ status: str = ModelVersionStatus.to_string(ModelVersionStatus.READY),
+ status_message: Optional[str] = None,
+ tags: Optional[List[ModelVersionTag]] = None,
+ run_link: Optional[str] = None,
+ aliases: Optional[List[str]] = None,
+ # TODO: Make model_id a required field
+ # (currently optional to minimize breakages during prototype development)
+ model_id: Optional[str] = None,
+ params: Optional[List[ModelParam]] = None,
+ metrics: Optional[List[Metric]] = None,
):
super().__init__()
- self._name = name
- self._version = version
- self._creation_time = creation_timestamp
- self._last_updated_timestamp = last_updated_timestamp
- self._description = description
- self._user_id = user_id
- self._current_stage = current_stage
- self._source = source
- self._run_id = run_id
- self._run_link = run_link
- self._status = status
- self._status_message = status_message
- self._tags = {tag.key: tag.value for tag in (tags or [])}
- self._aliases = aliases or []
-
- @property
- def name(self):
+ self._name: str = name
+ self._version: str = version
+ self._creation_time: int = creation_timestamp
+ self._last_updated_timestamp: Optional[int] = last_updated_timestamp
+ self._description: Optional[str] = description
+ self._user_id: Optional[str] = user_id
+ self._current_stage: Optional[str] = current_stage
+ self._source: Optional[str] = source
+ self._run_id: Optional[str] = run_id
+ self._run_link: Optional[str] = run_link
+ self._status: str = status
+ self._status_message: Optional[str] = status_message
+ self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])}
+ self._aliases: List[str] = aliases or []
+ self._model_id: Optional[str] = model_id
+ self._params: Optional[List[ModelParam]] = params
+ self._metrics: Optional[List[Metric]] = metrics
+
+ @property
+ def name(self) -> str:
"""String. Unique name within Model Registry."""
return self._name
@name.setter
- def name(self, new_name):
+ def name(self, new_name: str):
self._name = new_name
@property
- def version(self):
- """version"""
+ def version(self) -> str:
+ """Version"""
return self._version
@property
- def creation_timestamp(self):
+ def creation_timestamp(self) -> int:
"""Integer. Model version creation timestamp (milliseconds since the Unix epoch)."""
return self._creation_time
@property
- def last_updated_timestamp(self):
+ def last_updated_timestamp(self) -> Optional[int]:
"""Integer. Timestamp of last update for this model version (milliseconds since the Unix
epoch).
"""
return self._last_updated_timestamp
@last_updated_timestamp.setter
- def last_updated_timestamp(self, updated_timestamp):
+ def last_updated_timestamp(self, updated_timestamp: int):
self._last_updated_timestamp = updated_timestamp
@property
- def description(self):
+ def description(self) -> Optional[str]:
"""String. Description"""
return self._description
@description.setter
- def description(self, description):
+ def description(self, description: str):
self._description = description
@property
- def user_id(self):
+ def user_id(self) -> Optional[str]:
"""String. User ID that created this model version."""
return self._user_id
@property
- def current_stage(self):
+ def current_stage(self) -> Optional[str]:
"""String. Current stage of this model version."""
return self._current_stage
@current_stage.setter
- def current_stage(self, stage):
+ def current_stage(self, stage: str):
self._current_stage = stage
@property
- def source(self):
+ def source(self) -> Optional[str]:
"""String. Source path for the model."""
return self._source
@property
- def run_id(self):
+ def run_id(self) -> Optional[str]:
"""String. MLflow run ID that generated this model."""
return self._run_id
@property
- def run_link(self):
+ def run_link(self) -> Optional[str]:
"""String. MLflow run link referring to the exact run that generated this model version."""
return self._run_link
@property
- def status(self):
+ def status(self) -> str:
"""String. Current Model Registry status for this model."""
return self._status
@property
- def status_message(self):
+ def status_message(self) -> Optional[str]:
"""String. Descriptive message for error status conditions."""
return self._status_message
@property
- def tags(self):
+ def tags(self) -> Dict[str, str]:
"""Dictionary of tag key (string) -> tag value for the current model version."""
return self._tags
@property
- def aliases(self):
+ def aliases(self) -> List[str]:
"""List of aliases (string) for the current model version."""
return self._aliases
@aliases.setter
- def aliases(self, aliases):
+ def aliases(self, aliases: List[str]):
self._aliases = aliases
+ @property
+ def model_id(self) -> Optional[str]:
+ """String. ID of the model associated with this version."""
+ return self._model_id
+
+ @property
+ def params(self) -> Optional[List[ModelParam]]:
+ """List of parameters associated with this model version."""
+ return self._params
+
+ @property
+ def metrics(self) -> Optional[List[Metric]]:
+ """List of metrics associated with this model version."""
+ return self._metrics
+
@classmethod
- def _properties(cls):
+ def _properties(cls) -> List[str]:
# aggregate with base class properties since cls.__dict__ does not do it automatically
return sorted(cls._get_properties_helper())
- def _add_tag(self, tag):
+ def _add_tag(self, tag: ModelVersionTag):
self._tags[tag.key] = tag.value
# proto mappers
@classmethod
- def from_proto(cls, proto):
+ def from_proto(cls, proto: ProtoModelVersion) -> "ModelVersion":
# input: mlflow.protos.model_registry_pb2.ModelVersion
# returns: ModelVersion entity
model_version = cls(
@@ -165,9 +192,10 @@ def from_proto(cls, proto):
)
for tag in proto.tags:
model_version._add_tag(ModelVersionTag.from_proto(tag))
+ # TODO: Include params, metrics, and model ID in proto
return model_version
- def to_proto(self):
+ def to_proto(self) -> ProtoModelVersion:
# input: ModelVersion entity
# returns mlflow.protos.model_registry_pb2.ModelVersion
model_version = ProtoModelVersion()
@@ -196,4 +224,5 @@ def to_proto(self):
[ProtoModelVersionTag(key=key, value=value) for key, value in self._tags.items()]
)
model_version.aliases.extend(self.aliases)
+ # TODO: Include params, metrics, and model ID in proto
return model_version
diff --git a/mlflow/entities/model_status.py b/mlflow/entities/model_status.py
new file mode 100644
index 0000000000000..495eb0638022a
--- /dev/null
+++ b/mlflow/entities/model_status.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class ModelStatus(str, Enum):
+ """Enum for status of an :py:class:`mlflow.entities.LoggedModel`."""
+
+ PENDING = "PENDING"
+ READY = "READY"
+ FAILED = "FAILED"
diff --git a/mlflow/entities/model_tag.py b/mlflow/entities/model_tag.py
new file mode 100644
index 0000000000000..0774ce27759b1
--- /dev/null
+++ b/mlflow/entities/model_tag.py
@@ -0,0 +1,25 @@
+from mlflow.entities._mlflow_object import _MlflowObject
+
+
+class ModelTag(_MlflowObject):
+ """Tag object associated with a Model."""
+
+ def __init__(self, key, value):
+ self._key = key
+ self._value = value
+
+ def __eq__(self, other):
+ if type(other) is type(self):
+ # TODO deep equality here?
+ return self.__dict__ == other.__dict__
+ return False
+
+ @property
+ def key(self):
+ """String name of the tag."""
+ return self._key
+
+ @property
+ def value(self):
+ """String value of the tag."""
+ return self._value
diff --git a/mlflow/entities/run.py b/mlflow/entities/run.py
index 0fade6fa8daa2..8bdc68fe8d895 100644
--- a/mlflow/entities/run.py
+++ b/mlflow/entities/run.py
@@ -4,6 +4,7 @@
from mlflow.entities.run_data import RunData
from mlflow.entities.run_info import RunInfo
from mlflow.entities.run_inputs import RunInputs
+from mlflow.entities.run_outputs import RunOutputs
from mlflow.exceptions import MlflowException
from mlflow.protos.service_pb2 import Run as ProtoRun
@@ -14,13 +15,18 @@ class Run(_MlflowObject):
"""
def __init__(
- self, run_info: RunInfo, run_data: RunData, run_inputs: Optional[RunInputs] = None
+ self,
+ run_info: RunInfo,
+ run_data: RunData,
+ run_inputs: Optional[RunInputs] = None,
+ run_outputs: Optional[RunOutputs] = None,
) -> None:
if run_info is None:
raise MlflowException("run_info cannot be None")
self._info = run_info
self._data = run_data
self._inputs = run_inputs
+ self._outputs = run_outputs
@property
def info(self) -> RunInfo:
@@ -43,12 +49,21 @@ def data(self) -> RunData:
@property
def inputs(self) -> RunInputs:
"""
- The run inputs, including dataset inputs
+ The run inputs, including dataset inputs.
:rtype: :py:class:`mlflow.entities.RunInputs`
"""
return self._inputs
+ @property
+ def outputs(self) -> RunOutputs:
+ """
+ The run outputs, including model outputs.
+
+ :rtype: :py:class:`mlflow.entities.RunOutputs`
+ """
+ return self._outputs
+
def to_proto(self):
run = ProtoRun()
run.info.MergeFrom(self.info.to_proto())
@@ -56,6 +71,9 @@ def to_proto(self):
run.data.MergeFrom(self.data.to_proto())
if self.inputs:
run.inputs.MergeFrom(self.inputs.to_proto())
+ # TODO: Support proto conversion for RunOutputs
+ # if self.outputs:
+ # run.outputs.MergeFrom(self.outputs.to_proto())
return run
@classmethod
@@ -63,7 +81,9 @@ def from_proto(cls, proto):
return cls(
RunInfo.from_proto(proto.info),
RunData.from_proto(proto.data),
- RunInputs.from_proto(proto.inputs),
+ RunInputs.from_proto(proto.inputs) if proto.inputs else None,
+ # TODO: Support proto conversion for RunOutputs
+ # RunOutputs.from_proto(proto.outputs) if proto.outputs else None,
)
def to_dictionary(self) -> Dict[Any, Any]:
@@ -74,4 +94,6 @@ def to_dictionary(self) -> Dict[Any, Any]:
run_dict["data"] = self.data.to_dictionary()
if self.inputs:
run_dict["inputs"] = self.inputs.to_dictionary()
+ if self.outputs:
+ run_dict["outputs"] = self.outputs.to_dictionary()
return run_dict
diff --git a/mlflow/entities/run_inputs.py b/mlflow/entities/run_inputs.py
index d28f026c71bc3..e5b8f1ccc7c7f 100644
--- a/mlflow/entities/run_inputs.py
+++ b/mlflow/entities/run_inputs.py
@@ -2,14 +2,16 @@
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset_input import DatasetInput
+from mlflow.entities.model_input import ModelInput
from mlflow.protos.service_pb2 import RunInputs as ProtoRunInputs
class RunInputs(_MlflowObject):
"""RunInputs object."""
- def __init__(self, dataset_inputs: List[DatasetInput]) -> None:
+ def __init__(self, dataset_inputs: List[DatasetInput], model_inputs: List[ModelInput]) -> None:
self._dataset_inputs = dataset_inputs
+ self._model_inputs = model_inputs
def __eq__(self, other: _MlflowObject) -> bool:
if type(other) is type(self):
@@ -21,16 +23,26 @@ def dataset_inputs(self) -> List[DatasetInput]:
"""Array of dataset inputs."""
return self._dataset_inputs
+ @property
+ def model_inputs(self) -> List[ModelInput]:
+ """Array of model inputs."""
+ return self._model_inputs
+
def to_proto(self):
run_inputs = ProtoRunInputs()
run_inputs.dataset_inputs.extend(
[dataset_input.to_proto() for dataset_input in self.dataset_inputs]
)
+ # TODO: Support proto conversion for model inputs
+ # run_inputs.model_inputs.extend(
+ # [model_input.to_proto() for model_input in self.model_inputs]
+ # )
return run_inputs
def to_dictionary(self) -> Dict[Any, Any]:
return {
"dataset_inputs": self.dataset_inputs,
+ "model_inputs": self.model_inputs,
}
@classmethod
@@ -38,4 +50,8 @@ def from_proto(cls, proto):
dataset_inputs = [
DatasetInput.from_proto(dataset_input) for dataset_input in proto.dataset_inputs
]
- return cls(dataset_inputs)
+ # TODO: Support proto conversion for model inputs
+ # model_inputs = [
+ # ModelInput.from_proto(model_input) for model_input in proto.model_inputs
+ # ]
+ return cls(dataset_inputs, [])
diff --git a/mlflow/entities/run_outputs.py b/mlflow/entities/run_outputs.py
new file mode 100644
index 0000000000000..3d4b8a5f83b77
--- /dev/null
+++ b/mlflow/entities/run_outputs.py
@@ -0,0 +1,26 @@
+from typing import Any, Dict, List
+
+from mlflow.entities._mlflow_object import _MlflowObject
+from mlflow.entities.model_output import ModelOutput
+
+
+class RunOutputs(_MlflowObject):
+ """RunOutputs object."""
+
+ def __init__(self, model_outputs: List[ModelOutput]) -> None:
+ self._model_outputs = model_outputs
+
+ def __eq__(self, other: _MlflowObject) -> bool:
+ if type(other) is type(self):
+ return self.__dict__ == other.__dict__
+ return False
+
+ @property
+ def model_outputs(self) -> List[ModelOutput]:
+ """Array of model outputs."""
+ return self._model_outputs
+
+ def to_dictionary(self) -> Dict[Any, Any]:
+ return {
+ "model_outputs": self.model_outputs,
+ }
diff --git a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java
index 487c8f9a2864e..4139fe8cab66d 100644
--- a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java
+++ b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java
@@ -32,6 +32,10 @@ public enum InputVertexType
* DATASET = 2;
*/
DATASET(2),
+ /**
+ * MODEL = 3;
+ */
+ MODEL(3),
;
/**
@@ -42,6 +46,10 @@ public enum InputVertexType
* DATASET = 2;
*/
public static final int DATASET_VALUE = 2;
+ /**
+ * MODEL = 3;
+ */
+ public static final int MODEL_VALUE = 3;
public final int getNumber() {
@@ -66,6 +74,7 @@ public static InputVertexType forNumber(int value) {
switch (value) {
case 1: return RUN;
case 2: return DATASET;
+ case 3: return MODEL;
default: return null;
}
}
@@ -115,6 +124,107 @@ private InputVertexType(int value) {
// @@protoc_insertion_point(enum_scope:mlflow.internal.InputVertexType)
}
+ /**
+ *
+ * Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can
+ * have an output relationship.
+ *
+ *
+ * Protobuf enum {@code mlflow.internal.OutputVertexType}
+ */
+ public enum OutputVertexType
+ implements com.google.protobuf.ProtocolMessageEnum {
+ /**
+ * RUN_OUTPUT = 1;
+ */
+ RUN_OUTPUT(1),
+ /**
+ * MODEL_OUTPUT = 2;
+ */
+ MODEL_OUTPUT(2),
+ ;
+
+ /**
+ * RUN_OUTPUT = 1;
+ */
+ public static final int RUN_OUTPUT_VALUE = 1;
+ /**
+ * MODEL_OUTPUT = 2;
+ */
+ public static final int MODEL_OUTPUT_VALUE = 2;
+
+
+ public final int getNumber() {
+ return value;
+ }
+
+ /**
+ * @param value The numeric wire value of the corresponding enum entry.
+ * @return The enum associated with the given numeric wire value.
+ * @deprecated Use {@link #forNumber(int)} instead.
+ */
+ @java.lang.Deprecated
+ public static OutputVertexType valueOf(int value) {
+ return forNumber(value);
+ }
+
+ /**
+ * @param value The numeric wire value of the corresponding enum entry.
+ * @return The enum associated with the given numeric wire value.
+ */
+ public static OutputVertexType forNumber(int value) {
+ switch (value) {
+ case 1: return RUN_OUTPUT;
+ case 2: return MODEL_OUTPUT;
+ default: return null;
+ }
+ }
+
+ public static com.google.protobuf.Internal.EnumLiteMap
+ internalGetValueMap() {
+ return internalValueMap;
+ }
+ private static final com.google.protobuf.Internal.EnumLiteMap<
+ OutputVertexType> internalValueMap =
+ new com.google.protobuf.Internal.EnumLiteMap() {
+ public OutputVertexType findValueByNumber(int number) {
+ return OutputVertexType.forNumber(number);
+ }
+ };
+
+ public final com.google.protobuf.Descriptors.EnumValueDescriptor
+ getValueDescriptor() {
+ return getDescriptor().getValues().get(ordinal());
+ }
+ public final com.google.protobuf.Descriptors.EnumDescriptor
+ getDescriptorForType() {
+ return getDescriptor();
+ }
+ public static final com.google.protobuf.Descriptors.EnumDescriptor
+ getDescriptor() {
+ return org.mlflow.internal.proto.Internal.getDescriptor().getEnumTypes().get(1);
+ }
+
+ private static final OutputVertexType[] VALUES = values();
+
+ public static OutputVertexType valueOf(
+ com.google.protobuf.Descriptors.EnumValueDescriptor desc) {
+ if (desc.getType() != getDescriptor()) {
+ throw new java.lang.IllegalArgumentException(
+ "EnumValueDescriptor is not for this type.");
+ }
+ return VALUES[desc.getIndex()];
+ }
+
+ private final int value;
+
+ private OutputVertexType(int value) {
+ this.value = value;
+ }
+
+ // @@protoc_insertion_point(enum_scope:mlflow.internal.OutputVertexType)
+ }
+
public static com.google.protobuf.Descriptors.FileDescriptor
getDescriptor() {
@@ -125,9 +235,10 @@ private InputVertexType(int value) {
static {
java.lang.String[] descriptorData = {
"\n\016internal.proto\022\017mlflow.internal\032\025scala" +
- "pb/scalapb.proto*\'\n\017InputVertexType\022\007\n\003R" +
- "UN\020\001\022\013\n\007DATASET\020\002B#\n\031org.mlflow.internal" +
- ".proto\220\001\001\342?\002\020\001"
+ "pb/scalapb.proto*2\n\017InputVertexType\022\007\n\003R" +
+ "UN\020\001\022\013\n\007DATASET\020\002\022\t\n\005MODEL\020\003*4\n\020OutputVe" +
+ "rtexType\022\016\n\nRUN_OUTPUT\020\001\022\020\n\014MODEL_OUTPUT" +
+ "\020\002B#\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001"
};
descriptor = com.google.protobuf.Descriptors.FileDescriptor
.internalBuildGeneratedFileFrom(descriptorData,
diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py
index 688bba77b2301..23045b1ce97b1 100644
--- a/mlflow/langchain/__init__.py
+++ b/mlflow/langchain/__init__.py
@@ -406,7 +406,7 @@ def load_retriever(persist_directory):
@trace_disabled # Suppress traces for internal predict calls while logging model
def log_model(
lc_model,
- artifact_path,
+ name: Optional[str] = None,
conda_env=None,
code_paths=None,
registered_model_name=None,
@@ -422,6 +422,11 @@ def log_model(
run_id=None,
model_config=None,
streamable=None,
+ params: Optional[Dict[str, Any]] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ model_type: Optional[str] = None,
+ step: int = 0,
+ model_id: Optional[str] = None,
):
"""
Log a LangChain model as an MLflow artifact for the current run.
@@ -437,7 +442,7 @@ def log_model(
.. Note:: Experimental: Using model as path may change or be removed in a future
release without warning.
- artifact_path: Run-relative artifact path.
+ name: The name of the model.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
registered_model_name: This argument may change or be removed in a
@@ -547,7 +552,7 @@ def load_retriever(persist_directory):
metadata of the logged model.
"""
return Model.log(
- artifact_path=artifact_path,
+ name=name,
flavor=mlflow.langchain,
registered_model_name=registered_model_name,
lc_model=lc_model,
@@ -565,6 +570,11 @@ def load_retriever(persist_directory):
run_id=run_id,
model_config=model_config,
streamable=streamable,
+ params=params,
+ tags=tags,
+ model_type=model_type,
+ step=step,
+ model_id=model_id,
)
diff --git a/mlflow/models/model.py b/mlflow/models/model.py
index 170e89f5f2a57..2f9d807892ae5 100644
--- a/mlflow/models/model.py
+++ b/mlflow/models/model.py
@@ -14,6 +14,7 @@
import mlflow
from mlflow.artifacts import download_artifacts
+from mlflow.entities import Metric, ModelOutput, ModelStatus
from mlflow.exceptions import MlflowException
from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
@@ -654,13 +655,18 @@ def from_dict(cls, model_dict):
@classmethod
def log(
cls,
- artifact_path,
+ name,
flavor,
registered_model_name=None,
await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
metadata=None,
run_id=None,
resources=None,
+ model_type: Optional[str] = None,
+ params: Optional[Dict[str, Any]] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ step: int = 0,
+ model_id: Optional[str] = None,
**kwargs,
):
"""
@@ -668,7 +674,7 @@ def log(
active run.
Args:
- artifact_path: Run relative path identifying the model.
+ name: The name of the model.
flavor: Flavor module to save the model with. The module must have
the ``save_model`` function that will persist the model as a valid
MLflow model.
@@ -689,15 +695,79 @@ def log(
A :py:class:`ModelInfo ` instance that contains the
metadata of the logged model.
"""
- from mlflow.utils.model_utils import _validate_and_get_model_config_from_file
+ if (model_id, name).count(None) == 2:
+ raise MlflowException(
+ "Either `model_id` or `name` must be specified when logging a model. "
+ "Both are None.",
+ error_code=INVALID_PARAMETER_VALUE,
+ )
+
+ def log_model_metrics_for_step(client, model_id, run_id, step):
+ metric_names = client.get_run(run_id).data.metrics.keys()
+ metrics_for_step = []
+ for metric_name in metric_names:
+ history = client.get_metric_history(run_id, metric_name)
+ metrics_for_step.extend(
+ [
+ Metric(
+ key=metric.key,
+ value=metric.value,
+ timestamp=metric.timestamp,
+ step=metric.step,
+ dataset_name=metric.dataset_name,
+ dataset_digest=metric.dataset_digest,
+ run_id=metric.run_id,
+ model_id=model_id,
+ )
+ for metric in history
+ if metric.step == step and metric.model_id is None
+ ]
+ )
+ client.log_batch(run_id=run_id, metrics=metrics_for_step)
registered_model = None
with TempDir() as tmp:
local_path = tmp.path("model")
- if run_id is None:
- run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
+
+ tracking_uri = _resolve_tracking_uri()
+ client = mlflow.MlflowClient(tracking_uri)
+ active_run = mlflow.tracking.fluent.active_run()
+ if model_id is not None:
+ model = client.get_logged_model(model_id)
+ else:
+ params = {
+ **(params or {}),
+ **(
+ client.get_run(active_run.info.run_id).data.params
+ if active_run is not None
+ else {}
+ ),
+ }
+ model = client.create_logged_model(
+ experiment_id=mlflow.tracking.fluent._get_experiment_id(),
+ # TODO: Update model name
+ name=name,
+ run_id=active_run.info.run_id if active_run is not None else None,
+ model_type=model_type,
+ params={key: str(value) for key, value in params.items()},
+ tags={key: str(value) for key, value in tags.items()}
+ if tags is not None
+ else None,
+ )
+
+ if active_run is not None:
+ run_id = active_run.info.run_id
+ client.log_outputs(run_id=run_id, models=[ModelOutput(model.model_id, step=step)])
+ log_model_metrics_for_step(
+ client=client, model_id=model.model_id, run_id=run_id, step=step
+ )
+
mlflow_model = cls(
- artifact_path=artifact_path, run_id=run_id, metadata=metadata, resources=resources
+ artifact_path=model.artifact_location,
+ model_uuid=model.model_id,
+ run_id=active_run.info.run_id if active_run is not None else None,
+ metadata=metadata,
+ resources=resources,
)
flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
# `save_model` calls `load_model` to infer the model requirements, which may result in
@@ -708,7 +778,6 @@ def log(
if is_in_databricks_runtime():
_copy_model_metadata_for_uc_sharing(local_path, flavor)
- tracking_uri = _resolve_tracking_uri()
serving_input = mlflow_model.get_serving_input(local_path)
# We check signature presence here as some flavors have a default signature as a
# fallback when not provided by user, which is set during flavor's save_model() call.
@@ -717,66 +786,60 @@ def log(
_logger.warning(_LOG_MODEL_MISSING_INPUT_EXAMPLE_WARNING)
elif tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks":
_logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING)
- mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id)
-
- # if the model_config kwarg is passed in, then log the model config as an params
- if model_config := kwargs.get("model_config"):
- if isinstance(model_config, str):
- try:
- file_extension = os.path.splitext(model_config)[1].lower()
- if file_extension == ".json":
- with open(model_config) as f:
- model_config = json.load(f)
- elif file_extension in [".yaml", ".yml"]:
- model_config = _validate_and_get_model_config_from_file(model_config)
- else:
- _logger.warning(
- "Unsupported file format for model config: %s. "
- "Failed to load model config.",
- model_config,
- )
- except Exception as e:
- _logger.warning("Failed to load model config from %s: %s", model_config, e)
- try:
- from mlflow.models.utils import _flatten_nested_params
-
- # We are using the `/` separator to flatten the nested params
- # since we are using the same separator to log nested metrics.
- params_to_log = _flatten_nested_params(model_config, sep="/")
- except Exception as e:
- _logger.warning("Failed to flatten nested params: %s", str(e))
- params_to_log = model_config
-
- try:
- mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id)
- except Exception as e:
- _logger.warning("Failed to log model config as params: %s", str(e))
-
- try:
- mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id)
- except MlflowException:
- # We need to swallow all mlflow exceptions to maintain backwards compatibility with
- # older tracking servers. Only print out a warning for now.
- _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri())
- _logger.debug("", exc_info=True)
-
- if registered_model_name is not None:
- registered_model = mlflow.tracking._model_registry.fluent._register_model(
- f"runs:/{run_id}/{mlflow_model.artifact_path}",
- registered_model_name,
- await_registration_for=await_registration_for,
- local_model_path=local_path,
- )
- model_info = mlflow_model.get_model_info()
- if registered_model is not None:
- model_info.registered_model_version = registered_model.version
+ client.log_model_artifacts(model.model_id, local_path)
+ client.finalize_logged_model(model.model_id, status=ModelStatus.READY)
+
+ # # if the model_config kwarg is passed in, then log the model config as an params
+ # if model_config := kwargs.get("model_config"):
+ # if isinstance(model_config, str):
+ # try:
+ # file_extension = os.path.splitext(model_config)[1].lower()
+ # if file_extension == ".json":
+ # with open(model_config) as f:
+ # model_config = json.load(f)
+ # elif file_extension in [".yaml", ".yml"]:
+ # model_config = _validate_and_get_model_config_from_file(model_config)
+ # else:
+ # _logger.warning(
+ # "Unsupported file format for model config: %s. "
+ # "Failed to load model config.",
+ # model_config,
+ # )
+ # except Exception as e:
+ # _logger.warning(
+ # "Failed to load model config from %s: %s", model_config, e
+ # )
+ #
+ # try:
+ # from mlflow.models.utils import _flatten_nested_params
+ #
+ # # We are using the `/` separator to flatten the nested params
+ # # since we are using the same separator to log nested metrics.
+ # params_to_log = _flatten_nested_params(model_config, sep="/")
+ # except Exception as e:
+ # _logger.warning("Failed to flatten nested params: %s", str(e))
+ # params_to_log = model_config
+ #
+ # try:
+ # mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id)
+ # except Exception as e:
+ # _logger.warning("Failed to log model config as params: %s", str(e))
+ #
+ # try:
+ # mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id)
+ # except MlflowException:
+ # # We need to swallow all mlflow exceptions to maintain backwards compatibility
+ # # with older tracking servers. Only print out a warning for now.
+ # _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri())
+ # _logger.debug("", exc_info=True)
# validate input example works for serving when logging the model
if serving_input:
from mlflow.models import validate_serving_input
try:
+ model_info = mlflow_model.get_model_info()
validate_serving_input(model_info.model_uri, serving_input)
except Exception as e:
_logger.warning(
@@ -792,7 +855,21 @@ def log(
exc_info=_logger.isEnabledFor(logging.DEBUG),
)
- return model_info
+ if registered_model_name is not None:
+ registered_model = mlflow.tracking._model_registry.fluent._register_model(
+ f"models:/{model.model_id}",
+ registered_model_name,
+ await_registration_for=await_registration_for,
+ local_model_path=local_path,
+ )
+ return client.get_model_version(registered_model_name, registered_model.version)
+ else:
+ return client.get_logged_model(model.model_id)
+ # model_info = mlflow_model.get_model_info()
+ # if registered_model is not None:
+ # model_info.registered_model_version = registered_model.version
+
+ # return model_info
def _copy_model_metadata_for_uc_sharing(local_path, flavor):
diff --git a/mlflow/protos/internal.proto b/mlflow/protos/internal.proto
index 614a1916c1415..057e18fea1c9c 100644
--- a/mlflow/protos/internal.proto
+++ b/mlflow/protos/internal.proto
@@ -20,4 +20,14 @@ enum InputVertexType {
RUN = 1;
DATASET = 2;
+
+ MODEL = 3;
+}
+
+// Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can
+// have an output relationship.
+enum OutputVertexType {
+ RUN_OUTPUT = 1;
+
+ MODEL_OUTPUT = 2;
}
diff --git a/mlflow/protos/internal_pb2.py b/mlflow/protos/internal_pb2.py
index 7fcf97a79455a..7aa93249e6ff5 100644
--- a/mlflow/protos/internal_pb2.py
+++ b/mlflow/protos/internal_pb2.py
@@ -19,7 +19,7 @@
from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -28,7 +28,9 @@
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001'
_globals['_INPUTVERTEXTYPE']._serialized_start=58
- _globals['_INPUTVERTEXTYPE']._serialized_end=97
+ _globals['_INPUTVERTEXTYPE']._serialized_end=108
+ _globals['_OUTPUTVERTEXTYPE']._serialized_start=110
+ _globals['_OUTPUTVERTEXTYPE']._serialized_end=162
# @@protoc_insertion_point(module_scope)
else:
@@ -50,12 +52,17 @@
from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01')
_INPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['InputVertexType']
InputVertexType = enum_type_wrapper.EnumTypeWrapper(_INPUTVERTEXTYPE)
+ _OUTPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['OutputVertexType']
+ OutputVertexType = enum_type_wrapper.EnumTypeWrapper(_OUTPUTVERTEXTYPE)
RUN = 1
DATASET = 2
+ MODEL = 3
+ RUN_OUTPUT = 1
+ MODEL_OUTPUT = 2
if _descriptor._USE_C_DESCRIPTORS == False:
@@ -63,6 +70,8 @@
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001'
_INPUTVERTEXTYPE._serialized_start=58
- _INPUTVERTEXTYPE._serialized_end=97
+ _INPUTVERTEXTYPE._serialized_end=108
+ _OUTPUTVERTEXTYPE._serialized_start=110
+ _OUTPUTVERTEXTYPE._serialized_end=162
# @@protoc_insertion_point(module_scope)
diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py
index 2c79912cf47ce..0f47db32a67ad 100644
--- a/mlflow/pyfunc/__init__.py
+++ b/mlflow/pyfunc/__init__.py
@@ -2633,7 +2633,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
@trace_disabled # Suppress traces for internal predict calls while logging model
def log_model(
- artifact_path,
+ name=None,
loader_module=None,
data_path=None,
code_path=None, # deprecated
@@ -2653,6 +2653,11 @@ def log_model(
example_no_conversion=None,
streamable=None,
resources: Optional[Union[str, List[Resource]]] = None,
+ params: Optional[Dict[str, Any]] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ model_type: Optional[str] = None,
+ step: int = 0,
+ model_id: Optional[str] = None,
):
"""
Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow
@@ -2665,7 +2670,7 @@ def log_model(
and the parameters for the first workflow: ``python_model``, ``artifacts`` together.
Args:
- artifact_path: The run-relative artifact path to which to log the Python model.
+ name: The name of the model.
loader_module: The name of the Python module that is used to load the model
from ``data_path``. This module must define a method with the prototype
``_load_pyfunc(data_path)``. If not ``None``, this module and its
@@ -2852,7 +2857,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
metadata of the logged model.
"""
return Model.log(
- artifact_path=artifact_path,
+ name=name,
flavor=mlflow.pyfunc,
loader_module=loader_module,
data_path=data_path,
@@ -2873,6 +2878,11 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
streamable=streamable,
resources=resources,
infer_code_paths=infer_code_paths,
+ params=params,
+ tags=tags,
+ model_type=model_type,
+ step=step,
+ model_id=model_id,
)
diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py
index 198ac1ff6c609..02c453745d3d5 100644
--- a/mlflow/pytorch/__init__.py
+++ b/mlflow/pytorch/__init__.py
@@ -137,7 +137,7 @@ def get_default_conda_env():
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="torch"))
def log_model(
pytorch_model,
- artifact_path,
+ name: Optional[str] = None,
conda_env=None,
code_paths=None,
pickle_module=None,
@@ -150,6 +150,11 @@ def log_model(
pip_requirements=None,
extra_pip_requirements=None,
metadata=None,
+ params: Optional[Dict[str, Any]] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ model_type: Optional[str] = None,
+ step: int = 0,
+ model_id: Optional[str] = None,
**kwargs,
):
"""
@@ -177,7 +182,7 @@ class definition itself, should be included in one of the following locations:
``conda_env`` parameter.
- One or more of the files specified by the ``code_paths`` parameter.
- artifact_path: Run-relative artifact path.
+ name: The name of the model.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
pickle_module: The module that PyTorch should use to serialize ("pickle") the specified
@@ -293,7 +298,7 @@ class definition itself, should be included in one of the following locations:
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
return Model.log(
- artifact_path=artifact_path,
+ name=name,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
conda_env=conda_env,
@@ -308,6 +313,11 @@ class definition itself, should be included in one of the following locations:
pip_requirements=pip_requirements,
extra_pip_requirements=extra_pip_requirements,
metadata=metadata,
+ params=params,
+ tags=tags,
+ model_type=model_type,
+ step=step,
+ model_id=model_id,
**kwargs,
)
diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py
index 39b3cf7f5cc5b..7bc50a5f75b22 100644
--- a/mlflow/sklearn/__init__.py
+++ b/mlflow/sklearn/__init__.py
@@ -333,7 +333,7 @@ def save_model(
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def log_model(
sk_model,
- artifact_path,
+ name: Optional[str] = None,
conda_env=None,
code_paths=None,
serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
@@ -345,6 +345,11 @@ def log_model(
extra_pip_requirements=None,
pyfunc_predict_fn="predict",
metadata=None,
+ params: Optional[Dict[str, Any]] = None,
+ tags: Optional[Dict[str, Any]] = None,
+ model_type: Optional[str] = None,
+ step: int = 0,
+ model_id: Optional[str] = None,
):
"""
Log a scikit-learn model as an MLflow artifact for the current run. Produces an MLflow Model
@@ -356,7 +361,7 @@ def log_model(
Args:
sk_model: scikit-learn model to be saved.
- artifact_path: Run-relative artifact path.
+ name: Model name.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
serialization_format: The format in which to serialize the model. This should be one of
@@ -410,7 +415,7 @@ def log_model(
"""
return Model.log(
- artifact_path=artifact_path,
+ name=name,
flavor=mlflow.sklearn,
sk_model=sk_model,
conda_env=conda_env,
@@ -424,6 +429,11 @@ def log_model(
extra_pip_requirements=extra_pip_requirements,
pyfunc_predict_fn=pyfunc_predict_fn,
metadata=metadata,
+ params=params,
+ tags=tags,
+ model_type=model_type,
+ step=step,
+ model_id=model_id,
)
diff --git a/mlflow/store/artifact/models_artifact_repo.py b/mlflow/store/artifact/models_artifact_repo.py
index f54a062f22716..614743200d147 100644
--- a/mlflow/store/artifact/models_artifact_repo.py
+++ b/mlflow/store/artifact/models_artifact_repo.py
@@ -90,8 +90,15 @@ def _get_model_uri_infos(uri):
get_databricks_profile_uri_from_artifact_uri(uri) or mlflow.get_registry_uri()
)
client = MlflowClient(registry_uri=databricks_profile_uri)
- name, version = get_model_name_and_version(client, uri)
- download_uri = client.get_model_version_download_uri(name, version)
+ name_and_version_or_id = get_model_name_and_version(client, uri)
+ if len(name_and_version_or_id) == 1:
+ name = None
+ version = None
+ model_id = name_and_version_or_id[0]
+ download_uri = client.get_logged_model(model_id).artifact_location
+ else:
+ name, version = name_and_version_or_id
+ download_uri = client.get_model_version_download_uri(name, version)
return (
name,
diff --git a/mlflow/store/artifact/utils/models.py b/mlflow/store/artifact/utils/models.py
index e4347328d9e52..9e3cab029fc6b 100644
--- a/mlflow/store/artifact/utils/models.py
+++ b/mlflow/store/artifact/utils/models.py
@@ -37,7 +37,8 @@ def _get_latest_model_version(client, name, stage):
class ParsedModelUri(NamedTuple):
- name: str
+ model_id: Optional[str] = None
+ name: Optional[str] = None
version: Optional[str] = None
stage: Optional[str] = None
alias: Optional[str] = None
@@ -47,6 +48,7 @@ def _parse_model_uri(uri):
"""
Returns a ParsedModelUri tuple. Since a models:/ URI can only have one of
{version, stage, 'latest', alias}, it will return
+ - (id, None, None, None) to look for a specific model by ID,
- (name, version, None, None) to look for a specific version,
- (name, None, stage, None) to look for the latest version of a stage,
- (name, None, None, None) to look for the latest of all versions.
@@ -77,16 +79,21 @@ def _parse_model_uri(uri):
else:
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return ParsedModelUri(name, stage=suffix)
- else:
+ elif "@" in path:
# The URI is an alias URI, e.g. "models:/AdsModel1@Champion"
alias_parts = parts[0].rsplit("@", 1)
if len(alias_parts) != 2 or alias_parts[1].strip() == "":
raise MlflowException(_improper_model_uri_msg(uri))
return ParsedModelUri(alias_parts[0], alias=alias_parts[1])
+ else:
+ # The URI is of the form "models:/"
+ return ParsedModelUri(parts[0])
def get_model_name_and_version(client, models_uri):
- (model_name, model_version, model_stage, model_alias) = _parse_model_uri(models_uri)
+ (model_id, model_name, model_version, model_stage, model_alias) = _parse_model_uri(models_uri)
+ if model_id is not None:
+ return (model_id,)
if model_version is not None:
return model_name, model_version
if model_alias is not None:
diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py
index da5869fe5f825..13852e451a59c 100644
--- a/mlflow/store/model_registry/file_store.py
+++ b/mlflow/store/model_registry/file_store.py
@@ -5,7 +5,7 @@
import time
import urllib
from os.path import join
-from typing import List
+from typing import List, Optional
from mlflow.entities.model_registry import (
ModelVersion,
@@ -570,9 +570,23 @@ def _get_model_version_aliases(self, directory):
return [alias.alias for alias in aliases if alias.version == version]
def _get_file_model_version_from_dir(self, directory) -> FileModelVersion:
+ from mlflow.tracking.client import MlflowClient
+
meta = FileStore._read_yaml(directory, FileStore.META_DATA_FILE_NAME)
meta["tags"] = self._get_model_version_tags_from_dir(directory)
meta["aliases"] = self._get_model_version_aliases(directory)
+ # Fetch metrics and params from model ID
+ #
+ # TODO: Propagate tracking URI to file store directly, rather than relying on global
+ # URI (individual MlflowClient instances may have different tracking URIs)
+ if "model_id" in meta:
+ try:
+ model = MlflowClient().get_logged_model(meta["model_id"])
+ meta["metrics"] = model.metrics
+ meta["params"] = model.params
+ except Exception:
+ # TODO: Make this exception handling more specific
+ pass
return FileModelVersion.from_dictionary(meta)
def _save_model_version_as_meta_file(
@@ -605,6 +619,7 @@ def create_model_version(
run_link=None,
description=None,
local_model_path=None,
+ model_id: Optional[str] = None,
) -> ModelVersion:
"""
Create a new model version from given source and run ID.
@@ -617,6 +632,8 @@ def create_model_version(
instances associated with this model version.
run_link: Link to the run from an MLflow tracking server that generated this model.
description: Description of the version.
+ model_id: The ID of the model (from an Experiment) that is being promoted to a
+ registered model version, if applicable.
Returns:
A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
@@ -639,9 +656,19 @@ def next_version(registered_model_name):
if urllib.parse.urlparse(source).scheme == "models":
parsed_model_uri = _parse_model_uri(source)
try:
- storage_location = self.get_model_version_download_uri(
- parsed_model_uri.name, parsed_model_uri.version
- )
+ from mlflow.tracking.client import MlflowClient
+
+ if parsed_model_uri.model_id is not None:
+ # TODO: Propagate tracking URI to file store directly, rather than relying on
+ # global URI (individual MlflowClient instances may have different tracking
+ # URIs)
+ model = MlflowClient().get_logged_model(parsed_model_uri.model_id)
+ storage_location = model.artifact_location
+ run_id = run_id or model.run_id
+ else:
+ storage_location = self.get_model_version_download_uri(
+ parsed_model_uri.name, parsed_model_uri.version
+ )
except Exception as e:
raise MlflowException(
f"Unable to fetch model from model URI source artifact location '{source}'."
@@ -667,6 +694,7 @@ def next_version(registered_model_name):
tags=tags,
aliases=[],
storage_location=storage_location,
+ model_id=model_id,
)
model_version_dir = self._get_model_version_dir(name, version)
mkdir(model_version_dir)
@@ -677,7 +705,7 @@ def next_version(registered_model_name):
if tags is not None:
for tag in tags:
self.set_model_version_tag(name, version, tag)
- return model_version.to_mlflow_entity()
+ return self.get_model_version(name, version)
except Exception as e:
more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1
logging.warning(
diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py
index c2bdd63a09d1a..1a4cf44ac1f8c 100644
--- a/mlflow/store/tracking/file_store.py
+++ b/mlflow/store/tracking/file_store.py
@@ -5,8 +5,9 @@
import sys
import time
import uuid
+from collections import defaultdict
from dataclasses import dataclass
-from typing import Dict, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from mlflow.entities import (
Dataset,
@@ -14,12 +15,19 @@
Experiment,
ExperimentTag,
InputTag,
+ LoggedModel,
Metric,
+ ModelInput,
+ ModelOutput,
+ ModelParam,
+ ModelStatus,
+ ModelTag,
Param,
Run,
RunData,
RunInfo,
RunInputs,
+ RunOutputs,
RunStatus,
RunTag,
SourceType,
@@ -38,7 +46,7 @@
INVALID_PARAMETER_VALUE,
RESOURCE_DOES_NOT_EXIST,
)
-from mlflow.protos.internal_pb2 import InputVertexType
+from mlflow.protos.internal_pb2 import InputVertexType, OutputVertexType
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.model_registry.file_store import FileStore as ModelRegistryFileStore
from mlflow.store.tracking import (
@@ -159,6 +167,7 @@ class FileStore(AbstractStore):
EXPERIMENT_TAGS_FOLDER_NAME = "tags"
DATASETS_FOLDER_NAME = "datasets"
INPUTS_FOLDER_NAME = "inputs"
+ OUTPUTS_FOLDER_NAME = "outputs"
META_DATA_FILE_NAME = "meta.yaml"
DEFAULT_EXPERIMENT_ID = "0"
TRACE_INFO_FILE_NAME = "trace_info.yaml"
@@ -170,6 +179,7 @@ class FileStore(AbstractStore):
DATASETS_FOLDER_NAME,
TRACES_FOLDER_NAME,
]
+ MODELS_FOLDER_NAME = "models"
def __init__(self, root_directory=None, artifact_root_uri=None):
"""
@@ -237,6 +247,12 @@ def _get_metric_path(self, experiment_id, run_uuid, metric_key):
self._get_run_dir(experiment_id, run_uuid), FileStore.METRICS_FOLDER_NAME, metric_key
)
+ def _get_model_metric_path(self, experiment_id: str, model_id: str, metric_key: str) -> str:
+ _validate_metric_name(metric_key)
+ return os.path.join(
+ self._get_model_dir(experiment_id, model_id), FileStore.METRICS_FOLDER_NAME, metric_key
+ )
+
def _get_param_path(self, experiment_id, run_uuid, param_name):
_validate_run_id(run_uuid)
_validate_param_name(param_name)
@@ -687,11 +703,12 @@ def _get_run_from_info(self, run_info):
params = self._get_all_params(run_info)
tags = self._get_all_tags(run_info)
inputs: RunInputs = self._get_all_inputs(run_info)
+ outputs: RunOutputs = self._get_all_outputs(run_info)
if not run_info.run_name:
run_name = _get_run_name_from_tags(tags)
if run_name:
run_info._set_run_name(run_name)
- return Run(run_info, RunData(metrics, params, tags), inputs)
+ return Run(run_info, RunData(metrics, params, tags), inputs, outputs)
def _get_run_info(self, run_uuid):
"""
@@ -751,10 +768,12 @@ def _get_resource_files(self, root_dir, subfolder_name):
return source_dirs[0], file_names
@staticmethod
- def _get_metric_from_file(parent_path, metric_name, exp_id):
+ def _get_metric_from_file(
+ parent_path: str, metric_name: str, run_id: str, exp_id: str
+ ) -> Metric:
_validate_metric_name(metric_name)
metric_objs = [
- FileStore._get_metric_from_line(metric_name, line, exp_id)
+ FileStore._get_metric_from_line(run_id, metric_name, line, exp_id)
for line in read_file_lines(parent_path, metric_name)
]
if len(metric_objs) == 0:
@@ -775,24 +794,38 @@ def _get_all_metrics(self, run_info):
metrics = []
for metric_file in metric_files:
metrics.append(
- self._get_metric_from_file(parent_path, metric_file, run_info.experiment_id)
+ self._get_metric_from_file(
+ parent_path, metric_file, run_info.run_id, run_info.experiment_id
+ )
)
return metrics
@staticmethod
- def _get_metric_from_line(metric_name, metric_line, exp_id):
+ def _get_metric_from_line(
+ run_id: str, metric_name: str, metric_line: str, exp_id: str
+ ) -> Metric:
metric_parts = metric_line.strip().split(" ")
- if len(metric_parts) != 2 and len(metric_parts) != 3:
+ if len(metric_parts) != 2 and len(metric_parts) != 3 and len(metric_parts) != 5:
raise MlflowException(
f"Metric '{metric_name}' is malformed; persisted metric data contained "
- f"{len(metric_parts)} fields. Expected 2 or 3 fields. "
+ f"{len(metric_parts)} fields. Expected 2, 3, or 5 fields. "
f"Experiment id: {exp_id}",
databricks_pb2.INTERNAL_ERROR,
)
ts = int(metric_parts[0])
val = float(metric_parts[1])
step = int(metric_parts[2]) if len(metric_parts) == 3 else 0
- return Metric(key=metric_name, value=val, timestamp=ts, step=step)
+ dataset_name = str(metric_parts[3]) if len(metric_parts) == 5 else None
+ dataset_digest = str(metric_parts[4]) if len(metric_parts) == 5 else None
+ return Metric(
+ key=metric_name,
+ value=val,
+ timestamp=ts,
+ step=step,
+ dataset_name=dataset_name,
+ dataset_digest=dataset_digest,
+ run_id=run_id,
+ )
def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
"""
@@ -831,7 +864,7 @@ def get_metric_history(self, run_id, metric_key, max_results=None, page_token=No
return PagedList([], None)
return PagedList(
[
- FileStore._get_metric_from_line(metric_key, line, run_info.experiment_id)
+ FileStore._get_metric_from_line(run_id, metric_key, line, run_info.experiment_id)
for line in read_file_lines(parent_path, metric_key)
],
None,
@@ -945,17 +978,45 @@ def _search_runs(
runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results)
return runs, next_page_token
- def log_metric(self, run_id, metric):
+ def log_metric(self, run_id: str, metric: Metric):
_validate_run_id(run_id)
_validate_metric(metric.key, metric.value, metric.timestamp, metric.step)
run_info = self._get_run_info(run_id)
check_run_is_active(run_info)
self._log_run_metric(run_info, metric)
+ if metric.model_id is not None:
+ self._log_model_metric(
+ experiment_id=run_info.experiment_id,
+ model_id=metric.model_id,
+ run_id=run_id,
+ metric=metric,
+ )
def _log_run_metric(self, run_info, metric):
metric_path = self._get_metric_path(run_info.experiment_id, run_info.run_id, metric.key)
make_containing_dirs(metric_path)
- append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n")
+ if metric.dataset_name is not None and metric.dataset_digest is not None:
+ append_to(
+ metric_path,
+ f"{metric.timestamp} {metric.value} {metric.step} {metric.dataset_name} "
+ f"{metric.dataset_digest}\n",
+ )
+ else:
+ append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n")
+
+ def _log_model_metric(self, experiment_id: str, model_id: str, run_id: str, metric: Metric):
+ metric_path = self._get_model_metric_path(
+ experiment_id=experiment_id, model_id=model_id, metric_key=metric.key
+ )
+ make_containing_dirs(metric_path)
+ if metric.dataset_name is not None and metric.dataset_digest is not None:
+ append_to(
+ metric_path,
+ f"{metric.timestamp} {metric.value} {metric.step} {run_id} {metric.dataset_name} "
+ f"{metric.dataset_digest}\n",
+ )
+ else:
+ append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step} {run_id}\n")
def _writeable_value(self, tag_value):
if tag_value is None:
@@ -1077,6 +1138,13 @@ def log_batch(self, run_id, metrics, params, tags):
self._log_run_param(run_info, param)
for metric in metrics:
self._log_run_metric(run_info, metric)
+ if metric.model_id is not None:
+ self._log_model_metric(
+ experiment_id=run_info.experiment_id,
+ model_id=metric.model_id,
+ run_id=run_id,
+ metric=metric,
+ )
for tag in tags:
# NB: If the tag run name value is set, update the run info to assure
# synchronization.
@@ -1112,14 +1180,21 @@ def record_logged_model(self, run_id, mlflow_model):
except Exception as e:
raise MlflowException(e, INTERNAL_ERROR)
- def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None):
+ def log_inputs(
+ self,
+ run_id: str,
+ datasets: Optional[List[DatasetInput]] = None,
+ models: Optional[List[ModelInput]] = None,
+ ):
"""
- Log inputs, such as datasets, to the specified run.
+ Log inputs, such as datasets and models, to the specified run.
Args:
run_id: String id for the run
datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log
as inputs to the run.
+ models: List of :py:class:`mlflow.entities.ModelInput` instances to log
+ as inputs to the run.
Returns:
None.
@@ -1128,13 +1203,13 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None)
run_info = self._get_run_info(run_id)
check_run_is_active(run_info)
- if datasets is None:
+ if datasets is None and models is None:
return
experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True)
run_dir = self._get_run_dir(run_info.experiment_id, run_id)
- for dataset_input in datasets:
+ for dataset_input in datasets or []:
dataset = dataset_input.dataset
dataset_id = FileStore._get_dataset_id(
dataset_name=dataset.name, dataset_digest=dataset.digest
@@ -1144,7 +1219,7 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None)
os.makedirs(dataset_dir, exist_ok=True)
write_yaml(dataset_dir, FileStore.META_DATA_FILE_NAME, dict(dataset))
- input_id = FileStore._get_input_id(dataset_id=dataset_id, run_id=run_id)
+ input_id = FileStore._get_dataset_input_id(dataset_id=dataset_id, run_id=run_id)
input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id)
if not os.path.exists(input_dir):
os.makedirs(input_dir, exist_ok=True)
@@ -1157,6 +1232,57 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None)
)
fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME)
+ for model_input in models or []:
+ model_id = model_input.model_id
+ input_id = FileStore._get_model_input_id(model_id=model_id, run_id=run_id)
+ input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id)
+ if not os.path.exists(input_dir):
+ os.makedirs(input_dir, exist_ok=True)
+ fs_input = FileStore._FileStoreInput(
+ source_type=InputVertexType.MODEL,
+ source_id=model_id,
+ destination_type=InputVertexType.RUN,
+ destination_id=run_id,
+ tags={},
+ )
+ fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME)
+
+ def log_outputs(self, run_id, models: Optional[List[ModelOutput]] = None):
+ """
+ Log outputs, such as models, to the specified run.
+
+ Args:
+ run_id: String id for the run
+ models: List of :py:class:`mlflow.entities.ModelOutput` instances to log
+ as outputs of the run.
+
+ Returns:
+ None.
+ """
+ _validate_run_id(run_id)
+ run_info = self._get_run_info(run_id)
+ check_run_is_active(run_info)
+
+ if models is None:
+ return
+
+ run_dir = self._get_run_dir(run_info.experiment_id, run_id)
+
+        for model_output in models:
+            model_id = model_output.model_id
+            output_dir = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME, model_id)
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir, exist_ok=True)
+            fs_output = FileStore._FileStoreOutput(
+                source_type=OutputVertexType.RUN_OUTPUT,
+                source_id=run_id,
+                destination_type=OutputVertexType.MODEL_OUTPUT,
+                destination_id=model_id,
+                tags={},
+                step=model_output.step,
+            )
+            fs_output.write_yaml(output_dir, FileStore.META_DATA_FILE_NAME)
+
    @staticmethod
    def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str:
        md5 = insecure_hash.md5(dataset_name.encode("utf-8"))
@@ -1164,11 +1290,17 @@ def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str:
        return md5.hexdigest()
    @staticmethod
-    def _get_input_id(dataset_id: str, run_id: str) -> str:
+    def _get_dataset_input_id(dataset_id: str, run_id: str) -> str:
        md5 = insecure_hash.md5(dataset_id.encode("utf-8"))
        md5.update(run_id.encode("utf-8"))
        return md5.hexdigest()
+    @staticmethod
+    def _get_model_input_id(model_id: str, run_id: str) -> str:
+        md5 = insecure_hash.md5(model_id.encode("utf-8"))
+        md5.update(run_id.encode("utf-8"))
+        return md5.hexdigest()
+
    class _FileStoreInput(NamedTuple):
        source_type: int
        source_id: str
@@ -1197,13 +1329,54 @@ def from_yaml(cls, root, file_name):
            tags=dict_from_yaml["tags"],
        )
+    class _FileStoreOutput(NamedTuple):
+        source_type: int
+        source_id: str
+        destination_type: int
+        destination_id: str
+        tags: Dict[str, str]
+        step: int
+
+        def write_yaml(self, root: str, file_name: str):
+            dict_for_yaml = {
+                "source_type": OutputVertexType.Name(self.source_type),
+                "source_id": self.source_id,
+                "destination_type": OutputVertexType.Name(self.destination_type),
+                "destination_id": self.destination_id,
+                "tags": self.tags,
+                "step": self.step,
+            }
+            write_yaml(root, file_name, dict_for_yaml)
+
+ @classmethod
+ def from_yaml(cls, root, file_name):
+ dict_from_yaml = FileStore._read_yaml(root, file_name)
+ return cls(
+ source_type=OutputVertexType.Value(dict_from_yaml["source_type"]),
+ source_id=dict_from_yaml["source_id"],
+ destination_type=OutputVertexType.Value(dict_from_yaml["destination_type"]),
+ destination_id=dict_from_yaml["destination_id"],
+ tags=dict_from_yaml["tags"],
+ step=dict_from_yaml["step"],
+ )
+
def _get_all_inputs(self, run_info: RunInfo) -> RunInputs:
run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id)
inputs_parent_path = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME)
+ if not os.path.exists(inputs_parent_path):
+ return RunInputs(dataset_inputs=[], model_inputs=[])
+
experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True)
- datasets_parent_path = os.path.join(experiment_dir, FileStore.DATASETS_FOLDER_NAME)
- if not os.path.exists(inputs_parent_path) or not os.path.exists(datasets_parent_path):
- return RunInputs(dataset_inputs=[])
+ dataset_inputs = self._get_dataset_inputs(run_info, inputs_parent_path, experiment_dir)
+ model_inputs = self._get_model_inputs(inputs_parent_path, experiment_dir)
+ return RunInputs(dataset_inputs=dataset_inputs, model_inputs=model_inputs)
+
+ def _get_dataset_inputs(
+ self, run_info: RunInfo, inputs_parent_path: str, experiment_dir_path: str
+ ) -> List[DatasetInput]:
+ datasets_parent_path = os.path.join(experiment_dir_path, FileStore.DATASETS_FOLDER_NAME)
+ if not os.path.exists(datasets_parent_path):
+ return []
dataset_dirs = os.listdir(datasets_parent_path)
dataset_inputs = []
@@ -1213,9 +1386,6 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs:
input_dir_full_path, FileStore.META_DATA_FILE_NAME
)
if fs_input.source_type != InputVertexType.DATASET:
- logging.warning(
- f"Encountered invalid run input source type '{fs_input.source_type}'. Skipping."
- )
continue
matching_dataset_dirs = [d for d in dataset_dirs if d == fs_input.source_id]
@@ -1238,7 +1408,51 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs:
)
dataset_inputs.append(dataset_input)
- return RunInputs(dataset_inputs=dataset_inputs)
+ return dataset_inputs
+
+ def _get_model_inputs(
+ self, inputs_parent_path: str, experiment_dir_path: str
+ ) -> List[ModelInput]:
+ model_inputs = []
+ for input_dir in os.listdir(inputs_parent_path):
+ input_dir_full_path = os.path.join(inputs_parent_path, input_dir)
+ fs_input = FileStore._FileStoreInput.from_yaml(
+ input_dir_full_path, FileStore.META_DATA_FILE_NAME
+ )
+ if fs_input.source_type != InputVertexType.MODEL:
+ continue
+
+ model_input = ModelInput(model_id=fs_input.source_id)
+ model_inputs.append(model_input)
+
+ return model_inputs
+
+ def _get_all_outputs(self, run_info: RunInfo) -> RunOutputs:
+ run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id)
+ outputs_parent_path = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME)
+ if not os.path.exists(outputs_parent_path):
+ return RunOutputs(model_outputs=[])
+
+ experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True)
+ model_outputs = self._get_model_outputs(outputs_parent_path, experiment_dir)
+ return RunOutputs(model_outputs=model_outputs)
+
+ def _get_model_outputs(
+ self, outputs_parent_path: str, experiment_dir: str
+ ) -> List[ModelOutput]:
+ model_outputs = []
+ for output_dir in os.listdir(outputs_parent_path):
+ output_dir_full_path = os.path.join(outputs_parent_path, output_dir)
+ fs_output = FileStore._FileStoreOutput.from_yaml(
+ output_dir_full_path, FileStore.META_DATA_FILE_NAME
+ )
+ if fs_output.destination_type != OutputVertexType.MODEL_OUTPUT:
+ continue
+
+ model_output = ModelOutput(model_id=fs_output.destination_id, step=fs_output.step)
+ model_outputs.append(model_output)
+
+ return model_outputs
def _search_datasets(self, experiment_ids) -> List[_DatasetSummary]:
"""
@@ -1686,3 +1900,294 @@ def _list_trace_infos(self, experiment_id):
exc_info=_logger.isEnabledFor(logging.DEBUG),
)
return trace_infos
+
+ def create_logged_model(
+ self,
+ experiment_id: str,
+ name: str,
+ run_id: Optional[str] = None,
+ tags: Optional[List[ModelTag]] = None,
+ params: Optional[List[ModelParam]] = None,
+ model_type: Optional[str] = None,
+ ) -> LoggedModel:
+ """
+ Create a new model.
+
+ Args:
+ experiment_id: ID of the Experiment where the model is being created.
+ name: Name of the model.
+            run_id: Run ID where the model is being created from.
+            tags: Key-value tags for the model.
+            params: Key-value params for the model.
+            model_type: Type of the model.
+
+        Returns:
+            The created model.
+ """
+ experiment_id = FileStore.DEFAULT_EXPERIMENT_ID if experiment_id is None else experiment_id
+ experiment = self.get_experiment(experiment_id)
+ if experiment is None:
+ raise MlflowException(
+ "Could not create model under experiment with ID %s - no such experiment "
+ "exists." % experiment_id,
+ databricks_pb2.RESOURCE_DOES_NOT_EXIST,
+ )
+ if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
+ raise MlflowException(
+ f"Could not create model under non-active experiment with ID {experiment_id}.",
+ databricks_pb2.INVALID_STATE,
+ )
+ for param in params or []:
+ _validate_param(param.key, param.value)
+
+ model_id = str(uuid.uuid4())
+ artifact_location = self._get_model_artifact_dir(experiment_id, model_id)
+ creation_timestamp = int(time.time() * 1000)
+ model = LoggedModel(
+ experiment_id=experiment_id,
+ model_id=model_id,
+ name=name,
+ artifact_location=artifact_location,
+ creation_timestamp=creation_timestamp,
+ last_updated_timestamp=creation_timestamp,
+ run_id=run_id,
+ status=ModelStatus.PENDING,
+ tags=tags,
+ params=params,
+ model_type=model_type,
+ )
+
+ # Persist model metadata and create directories for logging metrics, tags
+ model_dir = self._get_model_dir(experiment_id, model_id)
+ mkdir(model_dir)
+ model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model)
+ write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict)
+ mkdir(model_dir, FileStore.METRICS_FOLDER_NAME)
+ for tag in tags or []:
+ self.set_logged_model_tag(model_id=model_id, tag=tag)
+
+ return self.get_logged_model(model_id=model_id)
+
+ def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel:
+ """
+ Finalize a model by updating its status.
+
+ Args:
+ model_id: ID of the model to finalize.
+ status: Final status to set on the model.
+
+ Returns:
+ The updated model.
+ """
+ if status != ModelStatus.READY:
+ raise MlflowException(
+ f"Invalid model status: {status}. Expected statuses: [{ModelStatus.READY}]",
+ databricks_pb2.INVALID_PARAMETER_VALUE,
+ )
+ model_dict = self._get_model_dict(model_id)
+ model = LoggedModel.from_dictionary(model_dict)
+ model.status = status
+ model.last_updated_timestamp = int(time.time() * 1000)
+ model_dir = self._get_model_dir(model.experiment_id, model.model_id)
+ model_info_dict = self._make_persisted_model_dict(model)
+ write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict, overwrite=True)
+ return self.get_logged_model(model_id)
+
+ def set_logged_model_tag(self, model_id: str, tag: ModelTag):
+ _validate_tag_name(tag.key)
+ model = self.get_logged_model(model_id)
+ tag_path = os.path.join(
+ self._get_model_dir(model.experiment_id, model.model_id),
+ FileStore.TAGS_FOLDER_NAME,
+ tag.key,
+ )
+ make_containing_dirs(tag_path)
+ # Don't add trailing newline
+ write_to(tag_path, self._writeable_value(tag.value))
+
+ def get_logged_model(self, model_id: str) -> LoggedModel:
+ return LoggedModel.from_dictionary(self._get_model_dict(model_id))
+
+ def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str:
+ return append_to_uri_path(
+ self.get_experiment(experiment_id).artifact_location,
+ FileStore.MODELS_FOLDER_NAME,
+ model_id,
+ FileStore.ARTIFACTS_FOLDER_NAME,
+ )
+
+ def _make_persisted_model_dict(self, model: LoggedModel) -> Dict[str, Any]:
+ model_dict = model.to_dictionary()
+ model_dict.pop("tags", None)
+ model_dict.pop("metrics", None)
+ return model_dict
+
+ def _get_model_dict(self, model_id: str) -> Dict[str, Any]:
+ exp_id, model_dir = self._find_model_root(model_id)
+ if model_dir is None:
+ raise MlflowException(
+ f"Model '{model_id}' not found", databricks_pb2.RESOURCE_DOES_NOT_EXIST
+ )
+ model_dict: Dict[str, Any] = self._get_model_info_from_dir(model_dir)
+ if model_dict["experiment_id"] != exp_id:
+ raise MlflowException(
+ f"Model '{model_id}' metadata is in invalid state.", databricks_pb2.INVALID_STATE
+ )
+ return model_dict
+
+    def _get_model_dir(self, experiment_id: str, model_id: str) -> Optional[str]:
+        if not self._has_experiment(experiment_id):
+            return None
+ return os.path.join(
+ self._get_experiment_path(experiment_id, assert_exists=True),
+ FileStore.MODELS_FOLDER_NAME,
+ model_id,
+ )
+
+ def _find_model_root(self, model_id):
+ self._check_root_dir()
+ all_experiments = self._get_active_experiments(False) + self._get_deleted_experiments(False)
+ for experiment_dir in all_experiments:
+ models_dir_path = os.path.join(
+ self.root_directory, experiment_dir, FileStore.MODELS_FOLDER_NAME
+ )
+ models = find(models_dir_path, model_id, full_path=True)
+ if len(models) == 0:
+ continue
+ return os.path.basename(os.path.dirname(os.path.abspath(models_dir_path))), models[0]
+ return None, None
+
+ def _get_model_from_dir(self, model_dir: str) -> LoggedModel:
+ return LoggedModel.from_dictionary(self._get_model_info_from_dir(model_dir))
+
+ def _get_model_info_from_dir(self, model_dir: str) -> Dict[str, Any]:
+ model_dict = FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME)
+ model_dict["tags"] = self._get_all_model_tags(model_dir)
+ model_dict["metrics"] = self._get_all_model_metrics(
+ model_id=model_dict["model_id"], model_dir=model_dir
+ )
+ return model_dict
+
+ def _get_all_model_tags(self, model_dir: str) -> List[ModelTag]:
+ parent_path, tag_files = self._get_resource_files(model_dir, FileStore.TAGS_FOLDER_NAME)
+ tags = []
+ for tag_file in tag_files:
+ tags.append(self._get_tag_from_file(parent_path, tag_file))
+ return tags
+
+ def _get_all_model_metrics(self, model_id: str, model_dir: str) -> List[Metric]:
+ parent_path, metric_files = self._get_resource_files(
+ model_dir, FileStore.METRICS_FOLDER_NAME
+ )
+ metrics = []
+ for metric_file in metric_files:
+ metrics.extend(
+ FileStore._get_model_metrics_from_file(
+ model_id=model_id, parent_path=parent_path, metric_name=metric_file
+ )
+ )
+ return metrics
+
+ @staticmethod
+ def _get_model_metrics_from_file(
+ model_id: str, parent_path: str, metric_name: str
+ ) -> List[Metric]:
+ _validate_metric_name(metric_name)
+ metric_objs = [
+ FileStore._get_model_metric_from_line(model_id, metric_name, line)
+ for line in read_file_lines(parent_path, metric_name)
+ ]
+ if len(metric_objs) == 0:
+ raise ValueError(f"Metric '{metric_name}' is malformed. No data found.")
+
+ # Group metrics by (dataset_name, dataset_digest)
+ grouped_metrics = defaultdict(list)
+ for metric in metric_objs:
+ key = (metric.dataset_name, metric.dataset_digest)
+ grouped_metrics[key].append(metric)
+
+ # Compute the max for each group
+ return [
+ max(group, key=lambda m: (m.step, m.timestamp, m.value))
+ for group in grouped_metrics.values()
+ ]
+
+ @staticmethod
+ def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: str) -> Metric:
+ metric_parts = metric_line.strip().split(" ")
+ if len(metric_parts) not in [4, 6]:
+ raise MlflowException(
+ f"Metric '{metric_name}' is malformed; persisted metric data contained "
+ f"{len(metric_parts)} fields. Expected 4 or 6 fields.",
+ databricks_pb2.INTERNAL_ERROR,
+ )
+ ts = int(metric_parts[0])
+ val = float(metric_parts[1])
+ step = int(metric_parts[2])
+        run_id = str(metric_parts[3])
+        dataset_name = str(metric_parts[4]) if len(metric_parts) == 6 else None
+        dataset_digest = str(metric_parts[5]) if len(metric_parts) == 6 else None
+ return Metric(
+ key=metric_name,
+ value=val,
+ timestamp=ts,
+ step=step,
+ model_id=model_id,
+ dataset_name=dataset_name,
+ dataset_digest=dataset_digest,
+ run_id=run_id,
+ )
+
+ def search_logged_models(
+ self,
+ experiment_ids: List[str],
+ filter_string: Optional[str] = None,
+ max_results: Optional[int] = None,
+ order_by: Optional[List[str]] = None,
+ ) -> List[LoggedModel]:
+        all_models = []
+        for experiment_id in experiment_ids:
+            all_models.extend(self._list_models(experiment_id))
+        filtered = SearchUtils.filter_logged_models(all_models, filter_string)
+        return SearchUtils.sort_logged_models(filtered, order_by)[:max_results]
+
+ def _list_models(self, experiment_id: str) -> List[LoggedModel]:
+ self._check_root_dir()
+ if not self._has_experiment(experiment_id):
+ return []
+ experiment_dir = self._get_experiment_path(experiment_id, assert_exists=True)
+ model_dirs = list_all(
+ os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME),
+ filter_func=lambda x: all(
+ os.path.basename(os.path.normpath(x)) != reservedFolderName
+ for reservedFolderName in FileStore.RESERVED_EXPERIMENT_FOLDERS
+ )
+ and os.path.isdir(x),
+ full_path=True,
+ )
+ models = []
+ for m_dir in model_dirs:
+ try:
+ # trap and warn known issues, will raise unexpected exceptions to caller
+ model = self._get_model_from_dir(m_dir)
+ if model.experiment_id != experiment_id:
+ logging.warning(
+ "Wrong experiment ID (%s) recorded for model '%s'. "
+ "It should be %s. Model will be ignored.",
+ str(model.experiment_id),
+ str(model.model_id),
+ str(experiment_id),
+ exc_info=True,
+ )
+ continue
+ models.append(model)
+ except MissingConfigException as exc:
+ # trap malformed model exception and log
+ # this is at debug level because if the same store is used for
+ # artifact storage, it's common the folder is not a run folder
+ m_id = os.path.basename(m_dir)
+ logging.debug(
+ "Malformed model '%s'. Detailed error %s", m_id, str(exc), exc_info=True
+ )
+ return models
diff --git a/mlflow/tracing/constant.py b/mlflow/tracing/constant.py
index 6fd5c0c025178..68eaa47c20aa9 100644
--- a/mlflow/tracing/constant.py
+++ b/mlflow/tracing/constant.py
@@ -3,6 +3,7 @@ class TraceMetadataKey:
INPUTS = "mlflow.traceInputs"
OUTPUTS = "mlflow.traceOutputs"
SOURCE_RUN = "mlflow.sourceRun"
+ MODEL_ID = "mlflow.modelId"
class TraceTagKey:
@@ -19,6 +20,7 @@ class SpanAttributeKey:
OUTPUTS = "mlflow.spanOutputs"
SPAN_TYPE = "mlflow.spanType"
FUNCTION_NAME = "mlflow.spanFunctionName"
+ MODEL_ID = "mlflow.modelId"
# All storage backends are guaranteed to support key values up to 250 characters
diff --git a/mlflow/tracing/display/display_handler.py b/mlflow/tracing/display/display_handler.py
index f8bff89122048..5da9553d2b618 100644
--- a/mlflow/tracing/display/display_handler.py
+++ b/mlflow/tracing/display/display_handler.py
@@ -90,6 +90,10 @@ def get_mimebundle(self, traces: List[Trace]):
}
def display_traces(self, traces: List[Trace]):
+        # Temporarily disable rendering of traces in Databricks notebooks,
+        # since it doesn't work with file-based storage
+ return
+
# This only works in Databricks notebooks
if not is_in_databricks_runtime():
return
diff --git a/mlflow/tracing/fluent.py b/mlflow/tracing/fluent.py
index c1c7d66306de8..dc0932acd05d1 100644
--- a/mlflow/tracing/fluent.py
+++ b/mlflow/tracing/fluent.py
@@ -58,6 +58,7 @@ def trace(
name: Optional[str] = None,
span_type: str = SpanType.UNKNOWN,
attributes: Optional[Dict[str, Any]] = None,
+ model_id: Optional[str] = None,
) -> Callable:
"""
A decorator that creates a new span for the decorated function.
@@ -135,7 +136,9 @@ class _WrappingContext:
def _wrapping_logic(fn, args, kwargs):
span_name = name or fn.__name__
- with start_span(name=span_name, span_type=span_type, attributes=attributes) as span:
+ with start_span(
+ name=span_name, span_type=span_type, attributes=attributes, model_id=model_id
+ ) as span:
span.set_attribute(SpanAttributeKey.FUNCTION_NAME, fn.__name__)
try:
span.set_inputs(capture_function_input_args(fn, args, kwargs))
@@ -184,6 +187,7 @@ def start_span(
name: str = "span",
span_type: Optional[str] = SpanType.UNKNOWN,
attributes: Optional[Dict[str, Any]] = None,
+ model_id: Optional[str] = None,
) -> Generator[LiveSpan, None, None]:
"""
Context manager to create a new span and start it as the current span in the context.
@@ -253,9 +257,11 @@ def start_span(
# Create a new MLflow span and register it to the in-memory trace manager
request_id = get_otel_attribute(otel_span, SpanAttributeKey.REQUEST_ID)
mlflow_span = create_mlflow_span(otel_span, request_id, span_type)
- mlflow_span.set_attributes(attributes or {})
+ attributes = dict(attributes) if attributes is not None else {}
+ if model_id is not None:
+ attributes[SpanAttributeKey.MODEL_ID] = model_id
+ mlflow_span.set_attributes(attributes)
InMemoryTraceManager.get_instance().register_span(mlflow_span)
-
except Exception as e:
_logger.warning(
f"Failed to start span: {e}. For full traceback, set logging level to debug.",
@@ -332,6 +338,7 @@ def search_traces(
max_results: Optional[int] = None,
order_by: Optional[List[str]] = None,
extract_fields: Optional[List[str]] = None,
+ model_id: Optional[str] = None,
) -> "pandas.DataFrame":
"""
Return traces that match the given list of search expressions within the experiments.
@@ -430,6 +437,7 @@ def pagination_wrapper_func(number_to_get, next_page_token):
filter_string=filter_string,
order_by=order_by,
page_token=next_page_token,
+ model_id=model_id,
)
results = get_results_from_paginated_fn(
diff --git a/mlflow/tracing/processor/mlflow.py b/mlflow/tracing/processor/mlflow.py
index bed4c48f275b4..21decbda90ffa 100644
--- a/mlflow/tracing/processor/mlflow.py
+++ b/mlflow/tracing/processor/mlflow.py
@@ -145,17 +145,21 @@ def on_end(self, span: OTelReadableSpan) -> None:
return
request_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID)
+ # TODO: We should remove the model ID from the span attributes
+ model_id = get_otel_attribute(span, SpanAttributeKey.MODEL_ID)
with self._trace_manager.get_trace(request_id) as trace:
if trace is None:
_logger.debug(f"Trace data with request ID {request_id} not found.")
return
- self._update_trace_info(trace, span)
+ self._update_trace_info(trace, span, model_id)
deduplicate_span_names_in_place(list(trace.span_dict.values()))
super().on_end(span)
- def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan):
+ def _update_trace_info(
+ self, trace: _Trace, root_span: OTelReadableSpan, model_id: Optional[str]
+ ):
"""Update the trace info with the final values from the root span."""
# The trace/span start time needs adjustment to exclude the latency of
# the backend API call. We already adjusted the span start time in the
@@ -173,6 +177,8 @@ def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan):
),
}
)
+ if model_id is not None:
+ trace.info.request_metadata[SpanAttributeKey.MODEL_ID] = model_id
def _truncate_metadata(self, value: Optional[str]) -> str:
"""Get truncated value of the attribute if it exceeds the maximum length."""
diff --git a/mlflow/tracking/_model_registry/client.py b/mlflow/tracking/_model_registry/client.py
index 5cbe391b7debe..d3773bc029077 100644
--- a/mlflow/tracking/_model_registry/client.py
+++ b/mlflow/tracking/_model_registry/client.py
@@ -4,6 +4,7 @@
exposed in the :py:mod:`mlflow.tracking` module.
"""
import logging
+from typing import Optional
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
from mlflow.exceptions import MlflowException
@@ -188,6 +189,7 @@ def create_model_version(
description=None,
await_creation_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
local_model_path=None,
+ model_id: Optional[str] = None,
):
"""Create a new model version from given source.
@@ -202,6 +204,8 @@ def create_model_version(
await_creation_for: Number of seconds to wait for the model version to finish being
created and is in ``READY`` status. By default, the function
waits for five minutes. Specify 0 or None to skip waiting.
+ model_id: The ID of the model (from an Experiment) that is being promoted to a
+ registered model version, if applicable.
Returns:
Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
@@ -220,6 +224,7 @@ def create_model_version(
run_link,
description,
local_model_path=local_model_path,
+ model_id=model_id,
)
else:
# Fall back to calling create_model_version without
diff --git a/mlflow/tracking/_model_registry/fluent.py b/mlflow/tracking/_model_registry/fluent.py
index 50892ab28063c..8d5b66c9d2de2 100644
--- a/mlflow/tracking/_model_registry/fluent.py
+++ b/mlflow/tracking/_model_registry/fluent.py
@@ -4,6 +4,7 @@
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import ALREADY_EXISTS, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
+from mlflow.store.artifact.utils.models import _parse_model_uri
from mlflow.store.model_registry import (
SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
@@ -75,7 +76,10 @@ def register_model(
Version: 1
"""
return _register_model(
- model_uri=model_uri, name=name, await_registration_for=await_registration_for, tags=tags
+ model_uri=model_uri,
+ name=name,
+ await_registration_for=await_registration_for,
+ tags=tags,
)
@@ -109,6 +113,7 @@ def _register_model(
source = RunsArtifactRepository.get_underlying_uri(model_uri)
(run_id, _) = RunsArtifactRepository.parse_runs_uri(model_uri)
+ parsed_model_uri = _parse_model_uri(model_uri)
create_version_response = client._create_model_version(
name=name,
source=source,
@@ -116,6 +121,7 @@ def _register_model(
tags=tags,
await_creation_for=await_registration_for,
local_model_path=local_model_path,
+ model_id=parsed_model_uri.model_id,
)
eprint(
f"Created version '{create_version_response.version}' of model "
diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py
index 70f5dd38681c0..f7a0c3f95e494 100644
--- a/mlflow/tracking/_tracking_service/client.py
+++ b/mlflow/tracking/_tracking_service/client.py
@@ -13,7 +13,13 @@
from mlflow.entities import (
ExperimentTag,
+ LoggedModel,
Metric,
+ ModelInput,
+ ModelOutput,
+ ModelParam,
+ ModelStatus,
+ ModelTag,
Param,
RunStatus,
RunTag,
@@ -32,6 +38,7 @@
MlflowTraceDataNotFound,
)
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE, ErrorCode
+from mlflow.store.artifact.artifact_repo import ArtifactRepository
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.tracking import (
@@ -296,7 +303,18 @@ def _search_traces(
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by: Optional[List[str]] = None,
page_token: Optional[str] = None,
+ model_id: Optional[str] = None,
):
+ if model_id is not None:
+ if filter_string:
+ raise MlflowException(
+ message=(
+                        "Cannot specify both `model_id` and `filter_string`"
+                        " in the search_traces call."
+ ),
+ error_code=INVALID_PARAMETER_VALUE,
+ )
+ filter_string = f"request_metadata.`mlflow.modelId` = '{model_id}'"
return self.store.search_traces(
experiment_ids=experiment_ids,
filter_string=filter_string,
@@ -312,6 +330,7 @@ def search_traces(
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by: Optional[List[str]] = None,
page_token: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> PagedList[Trace]:
def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]:
"""
@@ -343,6 +362,7 @@ def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]:
max_results=next_max_results,
order_by=order_by,
page_token=next_token,
+ model_id=model_id,
)
traces.extend(t for t in executor.map(download_trace_data, trace_infos) if t)
@@ -532,7 +552,16 @@ def rename_experiment(self, experiment_id, new_name):
self.store.rename_experiment(experiment_id, new_name)
def log_metric(
- self, run_id, key, value, timestamp=None, step=None, synchronous=True
+ self,
+ run_id,
+ key,
+ value,
+ timestamp=None,
+ step=None,
+ synchronous=True,
+ dataset_name: Optional[str] = None,
+ dataset_digest: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> Optional[RunOperations]:
"""Log a metric against the run ID.
@@ -559,7 +588,15 @@ def log_metric(
timestamp = timestamp if timestamp is not None else get_current_time_millis()
step = step if step is not None else 0
metric_value = convert_metric_value_to_float_if_possible(value)
- metric = Metric(key, metric_value, timestamp, step)
+ metric = Metric(
+ key,
+ metric_value,
+ timestamp,
+ step,
+ model_id=model_id,
+ dataset_name=dataset_name,
+ dataset_digest=dataset_digest,
+ )
if synchronous:
self.store.log_metric(run_id, metric)
else:
@@ -698,10 +735,14 @@ def log_batch(
metrics = [
Metric(
- metric.key,
- convert_metric_value_to_float_if_possible(metric.value),
- metric.timestamp,
- metric.step,
+ key=metric.key,
+ value=convert_metric_value_to_float_if_possible(metric.value),
+ timestamp=metric.timestamp,
+ step=metric.step,
+ dataset_name=metric.dataset_name,
+ dataset_digest=metric.dataset_digest,
+ model_id=metric.model_id,
+ run_id=metric.run_id,
)
for metric in metrics
]
@@ -752,12 +793,18 @@ def log_batch(
# Merge all the run operations into a single run operations object
return get_combined_run_operations(run_operations_list)
- def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None):
+ def log_inputs(
+ self,
+ run_id: str,
+ datasets: Optional[List[DatasetInput]] = None,
+ models: Optional[List[ModelInput]] = None,
+ ):
"""Log one or more dataset inputs to a run.
Args:
run_id: String ID of the run
datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log.
+ models: List of :py:class:`mlflow.entities.ModelInput` instances to log.
Raises:
MlflowException: If any errors occur.
@@ -765,10 +812,10 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None)
Returns:
None
"""
- if datasets is None or len(datasets) == 0:
- return
+ self.store.log_inputs(run_id=run_id, datasets=datasets, models=models)
- self.store.log_inputs(run_id=run_id, datasets=datasets)
+ def log_outputs(self, run_id: str, models: List[ModelOutput]):
+ self.store.log_outputs(run_id=run_id, models=models)
def _record_logged_model(self, run_id, mlflow_model):
from mlflow.models import Model
@@ -976,3 +1023,64 @@ def search_runs(
order_by=order_by,
page_token=page_token,
)
+
+ def create_logged_model(
+ self,
+ experiment_id: str,
+ name: str,
+ run_id: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, str]] = None,
+ model_type: Optional[str] = None,
+ ) -> LoggedModel:
+ return self.store.create_logged_model(
+ experiment_id=experiment_id,
+ name=name,
+ run_id=run_id,
+ tags=[ModelTag(str(key), str(value)) for key, value in tags.items()]
+ if tags is not None
+ else tags,
+ params=[ModelParam(str(key), str(value)) for key, value in params.items()]
+ if params is not None
+ else params,
+ model_type=model_type,
+ )
+
+ def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel:
+ return self.store.finalize_logged_model(model_id, status)
+
+ def get_logged_model(self, model_id: str) -> LoggedModel:
+ return self.store.get_logged_model(model_id)
+
+ def set_logged_model_tag(self, model_id: str, key: str, value: str):
+ return self.store.set_logged_model_tag(model_id, ModelTag(key, value))
+
+ def log_model_artifacts(self, model_id: str, local_dir: str) -> None:
+ self._get_artifact_repo_for_logged_model(model_id).log_artifacts(local_dir)
+
+ def search_logged_models(
+ self,
+ experiment_ids: List[str],
+ filter_string: Optional[str] = None,
+ max_results: Optional[int] = None,
+ order_by: Optional[List[str]] = None,
+ ):
+ return self.store.search_logged_models(experiment_ids, filter_string, max_results, order_by)
+
+ def _get_artifact_repo_for_logged_model(self, model_id: str) -> ArtifactRepository:
+ # Attempt to fetch the artifact repo from a local cache
+ cached_repo = utils._artifact_repos_cache.get(model_id)
+ if cached_repo is not None:
+ return cached_repo
+ else:
+ model = self.get_logged_model(model_id)
+ artifact_uri = add_databricks_profile_info_to_artifact_uri(
+ model.artifact_location, self.tracking_uri
+ )
+ artifact_repo = get_artifact_repository(artifact_uri)
+ # Cache the artifact repo to avoid a future network call, removing the oldest
+ # entry in the cache if there are too many elements
+ if len(utils._artifact_repos_cache) > 1024:
+ utils._artifact_repos_cache.popitem(last=False)
+ utils._artifact_repos_cache[model_id] = artifact_repo
+ return artifact_repo
diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py
index f85e93755e296..60394810cb44f 100644
--- a/mlflow/tracking/client.py
+++ b/mlflow/tracking/client.py
@@ -23,7 +23,11 @@
DatasetInput,
Experiment,
FileInfo,
+ LoggedModel,
Metric,
+ ModelInput,
+ ModelOutput,
+ ModelStatus,
Param,
Run,
RunTag,
@@ -485,6 +489,7 @@ def search_traces(
max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS,
order_by: Optional[List[str]] = None,
page_token: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> PagedList[Trace]:
"""
Return traces that match the given list of search expressions within the experiments.
@@ -511,6 +516,7 @@ def search_traces(
max_results=max_results,
order_by=order_by,
page_token=page_token,
+ model_id=model_id,
)
get_display_handler().display_traces(traces)
@@ -1440,6 +1446,9 @@ def log_metric(
timestamp: Optional[int] = None,
step: Optional[int] = None,
synchronous: Optional[bool] = None,
+ dataset_name: Optional[str] = None,
+ dataset_digest: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> Optional[RunOperations]:
"""
Log a metric against the run ID.
@@ -1515,7 +1524,15 @@ def print_run_info(r):
synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get()
)
return self._tracking_client.log_metric(
- run_id, key, value, timestamp, step, synchronous=synchronous
+ run_id,
+ key,
+ value,
+ timestamp,
+ step,
+ synchronous=synchronous,
+ dataset_name=dataset_name,
+ dataset_digest=dataset_digest,
+ model_id=model_id,
)
def log_param(
@@ -1860,6 +1877,7 @@ def log_inputs(
self,
run_id: str,
datasets: Optional[Sequence[DatasetInput]] = None,
+ models: Optional[Sequence[ModelInput]] = None,
) -> None:
"""
Log one or more dataset inputs to a run.
@@ -1867,11 +1885,15 @@ def log_inputs(
Args:
run_id: String ID of the run.
datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log.
+ models: List of :py:class:`mlflow.entities.ModelInput` instances to log.
Raises:
mlflow.MlflowException: If any errors occur.
"""
- self._tracking_client.log_inputs(run_id, datasets)
+ self._tracking_client.log_inputs(run_id, datasets, models)
+
+ def log_outputs(self, run_id: str, models: Sequence[ModelOutput]):
+ self._tracking_client.log_outputs(run_id, models)
def log_artifact(self, run_id, local_path, artifact_path=None) -> None:
"""Write a local file or directory to the remote ``artifact_uri``.
@@ -3585,6 +3607,7 @@ def _create_model_version(
description: Optional[str] = None,
await_creation_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
local_model_path: Optional[str] = None,
+ model_id: Optional[str] = None,
) -> ModelVersion:
tracking_uri = self._tracking_client.tracking_uri
if (
@@ -3627,6 +3650,7 @@ def _create_model_version(
description=description,
await_creation_for=await_creation_for,
local_model_path=local_model_path,
+ model_id=model_id,
)
def create_model_version(
@@ -4735,3 +4759,39 @@ def print_model_version_info(mv):
"""
_validate_model_name(name)
return self._get_registry_client().get_model_version_by_alias(name, alias)
+
+ def create_logged_model(
+ self,
+ experiment_id: str,
+ name: str,
+ run_id: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, str]] = None,
+ model_type: Optional[str] = None,
+ ) -> LoggedModel:
+ return self._tracking_client.create_logged_model(
+ experiment_id, name, run_id, tags, params, model_type
+ )
+
+ def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel:
+ return self._tracking_client.finalize_logged_model(model_id, status)
+
+ def get_logged_model(self, model_id: str) -> LoggedModel:
+ return self._tracking_client.get_logged_model(model_id)
+
+ def set_logged_model_tag(self, model_id: str, key: str, value: str):
+ return self._tracking_client.set_logged_model_tag(model_id, key, value)
+
+ def log_model_artifacts(self, model_id: str, local_dir: str) -> None:
+ return self._tracking_client.log_model_artifacts(model_id, local_dir)
+
+ def search_logged_models(
+ self,
+ experiment_ids: List[str],
+ filter_string: Optional[str] = None,
+ max_results: Optional[int] = None,
+ order_by: Optional[List[str]] = None,
+ ):
+ return self._tracking_client.search_logged_models(
+ experiment_ids, filter_string, max_results, order_by
+ )
diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py
index 257627f316b62..4e4d3f5fc10a7 100644
--- a/mlflow/tracking/fluent.py
+++ b/mlflow/tracking/fluent.py
@@ -17,7 +17,9 @@
DatasetInput,
Experiment,
InputTag,
+ LoggedModel,
Metric,
+ ModelInput,
Param,
Run,
RunStatus,
@@ -825,6 +827,8 @@ def log_metric(
synchronous: Optional[bool] = None,
timestamp: Optional[int] = None,
run_id: Optional[str] = None,
+ model_id: Optional[str] = None,
+ dataset: Optional[Dataset] = None,
) -> Optional[RunOperations]:
"""
Log a metric under the current run. If no run is active, this method will create
@@ -868,14 +872,77 @@ def log_metric(
"""
run_id = run_id or _get_or_start_run().info.run_id
synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get()
- return MlflowClient().log_metric(
+ _log_inputs_for_metrics_if_necessary(
run_id,
- key,
- value,
- timestamp or get_current_time_millis(),
- step or 0,
- synchronous=synchronous,
+ [
+ Metric(
+ key=key,
+ value=value,
+ timestamp=timestamp or get_current_time_millis(),
+ step=step or 0,
+ model_id=model_id,
+ dataset_name=dataset.name if dataset is not None else None,
+ dataset_digest=dataset.digest if dataset is not None else None,
+ ),
+ ],
+ datasets=[dataset] if dataset is not None else None,
+ )
+ timestamp = timestamp or get_current_time_millis()
+ step = step or 0
+ model_ids = (
+ [model_id]
+ if model_id is not None
+ else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None])
)
+ for model_id in model_ids:
+ return MlflowClient().log_metric(
+ run_id,
+ key,
+ value,
+ timestamp,
+ step,
+ synchronous=synchronous,
+ model_id=model_id,
+ dataset_name=dataset.name if dataset is not None else None,
+ dataset_digest=dataset.digest if dataset is not None else None,
+ )
+
+
+def _log_inputs_for_metrics_if_necessary(
+ run_id, metrics: List[Metric], datasets: Optional[List[Dataset]] = None
+) -> None:
+ client = MlflowClient()
+ run = client.get_run(run_id)
+ datasets = datasets or []
+ for metric in metrics:
+ if metric.model_id is not None and metric.model_id not in [
+ inp.model_id for inp in run.inputs.model_inputs
+ ] + [output.model_id for output in run.outputs.model_outputs]:
+ client.log_inputs(run_id, models=[ModelInput(model_id=metric.model_id)])
+ if (metric.dataset_name, metric.dataset_digest) not in [
+ (inp.dataset.name, inp.dataset.digest) for inp in run.inputs.dataset_inputs
+ ]:
+ matching_dataset = next(
+ (
+ dataset
+ for dataset in datasets
+ if dataset.name == metric.dataset_name
+ and dataset.digest == metric.dataset_digest
+ ),
+ None,
+ )
+ if matching_dataset is not None:
+ client.log_inputs(
+ run_id,
+ datasets=[DatasetInput(matching_dataset._to_mlflow_entity(), tags=[])],
+ )
+
+
+def _get_model_ids_for_new_metric_if_exist(run_id: str, metric_step: int) -> List[str]:
+ client = MlflowClient()
+ run = client.get_run(run_id)
+ model_outputs_at_step = [mo for mo in run.outputs.model_outputs if mo.step == metric_step]
+ return [mo.model_id for mo in model_outputs_at_step]
def log_metrics(
@@ -884,6 +951,8 @@ def log_metrics(
synchronous: Optional[bool] = None,
run_id: Optional[str] = None,
timestamp: Optional[int] = None,
+ model_id: Optional[str] = None,
+ dataset: Optional[Dataset] = None,
) -> Optional[RunOperations]:
"""
Log multiple metrics for the current run. If no run is active, this method will create a new
@@ -925,10 +994,37 @@ def log_metrics(
"""
run_id = run_id or _get_or_start_run().info.run_id
timestamp = timestamp or get_current_time_millis()
- metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
+ step = step or 0
+ dataset_name = dataset.name if dataset is not None else None
+ dataset_digest = dataset.digest if dataset is not None else None
+ model_ids = (
+ [model_id]
+ if model_id is not None
+ else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None])
+ )
+ metrics_arr = [
+ Metric(
+ key,
+ value,
+ timestamp,
+ step or 0,
+ model_id=model_id,
+ dataset_name=dataset_name,
+ dataset_digest=dataset_digest,
+ )
+ for key, value in metrics.items()
+ for model_id in model_ids
+ ]
+ _log_inputs_for_metrics_if_necessary(
+ run_id, metrics_arr, [dataset] if dataset is not None else None
+ )
synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get()
return MlflowClient().log_batch(
- run_id=run_id, metrics=metrics_arr, params=[], tags=[], synchronous=synchronous
+ run_id=run_id,
+ metrics=metrics_arr,
+ params=[],
+ tags=[],
+ synchronous=synchronous,
)
@@ -1815,6 +1911,60 @@ def delete_experiment(experiment_id: str) -> None:
MlflowClient().delete_experiment(experiment_id)
+def create_logged_model(
+ name: str,
+ run_id: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ params: Optional[Dict[str, str]] = None,
+ model_type: Optional[str] = None,
+ experiment_id: Optional[str] = None,
+) -> LoggedModel:
+ run = active_run()
+ if run_id is None and run is not None:
+ run_id = run.info.run_id
+ experiment_id = experiment_id if experiment_id is not None else _get_experiment_id()
+ return MlflowClient().create_logged_model(
+ experiment_id=experiment_id,
+ name=name,
+ run_id=run_id,
+ tags=tags,
+ params=params,
+ model_type=model_type,
+ )
+
+
+def get_logged_model(model_id: str) -> LoggedModel:
+ return MlflowClient().get_logged_model(model_id)
+
+
+def search_logged_models(
+ experiment_ids: Optional[List[str]] = None,
+ filter_string: Optional[str] = None,
+ max_results: Optional[int] = None,
+ order_by: Optional[List[str]] = None,
+ output_format: str = "pandas",
+) -> Union[List[LoggedModel], "pandas.DataFrame"]:
+ experiment_ids = experiment_ids or [_get_experiment_id()]
+ models = MlflowClient().search_logged_models(
+ experiment_ids=experiment_ids,
+ filter_string=filter_string,
+ max_results=max_results,
+ order_by=order_by,
+ )
+ if output_format == "pandas":
+ import pandas as pd
+
+ return pd.DataFrame([model.to_dictionary() for model in models])
+ elif output_format == "list":
+ return models
+ else:
+ raise MlflowException(
+ "Unsupported output format: %s. Supported string values are 'pandas' or 'list'"
+ % output_format,
+ INVALID_PARAMETER_VALUE,
+ )
+
+
def delete_run(run_id: str) -> None:
"""
Deletes a run with the given ID.
diff --git a/mlflow/utils/search_utils.py b/mlflow/utils/search_utils.py
index aa399f93f0e99..0a983421cfc01 100644
--- a/mlflow/utils/search_utils.py
+++ b/mlflow/utils/search_utils.py
@@ -5,7 +5,7 @@
import operator
import re
import shlex
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
import sqlparse
from packaging.version import Version
@@ -20,7 +20,7 @@
)
from sqlparse.tokens import Token as TokenType
-from mlflow.entities import RunInfo
+from mlflow.entities import LoggedModel, RunInfo
from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
@@ -635,6 +635,38 @@ def _does_run_match_clause(cls, run, sed):
return SearchUtils.get_comparison_func(comparator)(lhs, value)
+ @classmethod
+ def _does_model_match_clause(cls, model, sed):
+ key_type = sed.get("type")
+ key = sed.get("key")
+ value = sed.get("value")
+ comparator = sed.get("comparator").upper()
+
+ key = SearchUtils.translate_key_alias(key)
+
+ if cls.is_metric(key_type, comparator):
+ matching_metrics = [metric for metric in model.metrics if metric.key == key]
+ lhs = matching_metrics[0].value if matching_metrics else None
+ value = float(value)
+ elif cls.is_param(key_type, comparator):
+ lhs = model.params.get(key, None)
+ elif cls.is_tag(key_type, comparator):
+ lhs = model.tags.get(key, None)
+ elif cls.is_string_attribute(key_type, key, comparator):
+ lhs = getattr(model.info, key)
+ elif cls.is_numeric_attribute(key_type, key, comparator):
+ lhs = getattr(model.info, key)
+ value = int(value)
+ else:
+ raise MlflowException(
+ f"Invalid model search expression type '{key_type}'",
+ error_code=INVALID_PARAMETER_VALUE,
+ )
+ if lhs is None:
+ return False
+
+ return SearchUtils.get_comparison_func(comparator)(lhs, value)
+
@classmethod
def filter(cls, runs, filter_string):
"""Filters a set of runs based on a search filter string."""
@@ -647,6 +679,20 @@ def run_matches(run):
return [run for run in runs if run_matches(run)]
+ @classmethod
+ def filter_logged_models(cls, models: List[LoggedModel], filter_string: Optional[str] = None):
+        """Filters a set of logged models based on a search filter string."""
+ if not filter_string:
+ return models
+
+ # TODO: Update parsing function to handle model-specific filter clauses
+ parsed = cls.parse_search_filter(filter_string)
+
+ def model_matches(model):
+ return all(cls._does_model_match_clause(model, s) for s in parsed)
+
+ return [model for model in models if model_matches(model)]
+
@classmethod
def _validate_order_by_and_generate_token(cls, order_by):
try:
@@ -760,6 +806,40 @@ def _get_value_for_sort(cls, run, key_type, key, ascending):
return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value)
+ @classmethod
+ def _get_model_value_for_sort(cls, model, key_type, key, ascending):
+ """Returns a tuple suitable to be used as a sort key for models."""
+ sort_value = None
+ key = SearchUtils.translate_key_alias(key)
+ if key_type == cls._METRIC_IDENTIFIER:
+ matching_metrics = [metric for metric in model.metrics if metric.key == key]
+ sort_value = float(matching_metrics[0].value) if matching_metrics else None
+ elif key_type == cls._PARAM_IDENTIFIER:
+ sort_value = model.params.get(key)
+ elif key_type == cls._TAG_IDENTIFIER:
+ sort_value = model.tags.get(key)
+ elif key_type == cls._ATTRIBUTE_IDENTIFIER:
+ sort_value = getattr(model, key)
+ else:
+ raise MlflowException(
+ f"Invalid models order_by entity type '{key_type}'",
+ error_code=INVALID_PARAMETER_VALUE,
+ )
+
+ # Return a key such that None values are always at the end.
+ is_none = sort_value is None
+ is_nan = isinstance(sort_value, float) and math.isnan(sort_value)
+ fill_value = (1 if ascending else -1) * math.inf
+
+ if is_none:
+ sort_value = fill_value
+ elif is_nan:
+ sort_value = -fill_value
+
+ is_none_or_nan = is_none or is_nan
+
+ return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value)
+
@classmethod
def sort(cls, runs, order_by_list):
"""Sorts a set of runs based on their natural ordering and an overriding set of order_bys.
@@ -780,6 +860,24 @@ def sort(cls, runs, order_by_list):
)
return runs
+ @classmethod
+ def sort_logged_models(cls, models, order_by_list):
+ models = sorted(models, key=lambda model: (-model.creation_timestamp, model.model_id))
+ if not order_by_list:
+ return models
+ # NB: We rely on the stability of Python's sort function, so that we can apply
+ # the ordering conditions in reverse order.
+ for order_by_clause in reversed(order_by_list):
+ # TODO: Update parsing function to handle model-specific order-by keys
+ (key_type, key, ascending) = cls.parse_order_by_for_search_runs(order_by_clause)
+
+ models = sorted(
+ models,
+ key=lambda model: cls._get_model_value_for_sort(model, key_type, key, ascending),
+ reverse=not ascending,
+ )
+ return models
+
@classmethod
def parse_start_offset_from_page_token(cls, page_token):
# Note: the page_token is expected to be a base64-encoded JSON that looks like