diff --git a/mlflow/__init__.py b/mlflow/__init__.py index 404fe895eb2a1..5382224912a24 100644 --- a/mlflow/__init__.py +++ b/mlflow/__init__.py @@ -126,6 +126,7 @@ active_run, autolog, create_experiment, + create_logged_model, delete_experiment, delete_run, delete_tag, @@ -136,6 +137,7 @@ get_artifact_uri, get_experiment, get_experiment_by_name, + get_logged_model, get_parent_run, get_run, last_active_run, @@ -153,6 +155,7 @@ log_table, log_text, search_experiments, + search_logged_models, search_runs, set_experiment, set_experiment_tag, @@ -173,6 +176,7 @@ "active_run", "autolog", "create_experiment", + "create_logged_model", "delete_experiment", "delete_run", "delete_tag", @@ -188,6 +192,7 @@ "get_experiment", "get_experiment_by_name", "get_last_active_trace", + "get_logged_model", "get_parent_run", "get_registry_uri", "get_run", @@ -212,6 +217,7 @@ "register_model", "run", "search_experiments", + "search_logged_models", "search_model_versions", "search_registered_models", "search_runs", diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index 483d39835e95b..4c02c95f43fd8 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -11,12 +11,19 @@ from mlflow.entities.file_info import FileInfo from mlflow.entities.input_tag import InputTag from mlflow.entities.lifecycle_stage import LifecycleStage +from mlflow.entities.logged_model import LoggedModel from mlflow.entities.metric import Metric +from mlflow.entities.model_input import ModelInput +from mlflow.entities.model_output import ModelOutput +from mlflow.entities.model_param import ModelParam +from mlflow.entities.model_status import ModelStatus +from mlflow.entities.model_tag import ModelTag from mlflow.entities.param import Param from mlflow.entities.run import Run from mlflow.entities.run_data import RunData from mlflow.entities.run_info import RunInfo from mlflow.entities.run_inputs import RunInputs +from mlflow.entities.run_outputs import RunOutputs from 
mlflow.entities.run_status import RunStatus from mlflow.entities.run_tag import RunTag from mlflow.entities.source_type import SourceType @@ -46,6 +53,7 @@ "InputTag", "DatasetInput", "RunInputs", + "RunOutputs", "Span", "LiveSpan", "NoOpSpan", @@ -57,4 +65,10 @@ "TraceInfo", "SpanStatusCode", "_DatasetSummary", + "LoggedModel", + "ModelInput", + "ModelOutput", + "ModelStatus", + "ModelTag", + "ModelParam", ] diff --git a/mlflow/entities/logged_model.py b/mlflow/entities/logged_model.py new file mode 100644 index 0000000000000..b4fb1baa13757 --- /dev/null +++ b/mlflow/entities/logged_model.py @@ -0,0 +1,162 @@ +from typing import Any, Dict, List, Optional, Union + +from mlflow.entities._mlflow_object import _MlflowObject +from mlflow.entities.metric import Metric +from mlflow.entities.model_param import ModelParam +from mlflow.entities.model_status import ModelStatus +from mlflow.entities.model_tag import ModelTag + + +class LoggedModel(_MlflowObject): + """ + MLflow entity representing a Model logged to an MLflow Experiment. 
+ """ + + def __init__( + self, + experiment_id: str, + model_id: str, + name: str, + artifact_location: str, + creation_timestamp: int, + last_updated_timestamp: int, + model_type: Optional[str] = None, + run_id: Optional[str] = None, + status: ModelStatus = ModelStatus.READY, + status_message: Optional[str] = None, + tags: Optional[Union[List[ModelTag], Dict[str, str]]] = None, + params: Optional[Union[List[ModelParam], Dict[str, str]]] = None, + metrics: Optional[List[Metric]] = None, + ): + super().__init__() + self._experiment_id: str = experiment_id + self._model_id: str = model_id + self._name: str = name + self._artifact_location: str = artifact_location + self._creation_time: int = creation_timestamp + self._last_updated_timestamp: int = last_updated_timestamp + self._model_type: Optional[str] = model_type + self._run_id: Optional[str] = run_id + self._status: ModelStatus = status + self._status_message: Optional[str] = status_message + self._tags: Dict[str, str] = ( + {tag.key: tag.value for tag in (tags or [])} if isinstance(tags, list) else (tags or {}) + ) + self._params: Dict[str, str] = ( + {param.key: param.value for param in (params or [])} + if isinstance(params, list) + else (params or {}) + ) + self._metrics: Optional[List[Metric]] = metrics + + @property + def experiment_id(self) -> str: + """String. Experiment ID associated with this Model.""" + return self._experiment_id + + @experiment_id.setter + def experiment_id(self, new_experiment_id: str): + self._experiment_id = new_experiment_id + + @property + def model_id(self) -> str: + """String. Unique ID for this Model.""" + return self._model_id + + @model_id.setter + def model_id(self, new_model_id: str): + self._model_id = new_model_id + + @property + def name(self) -> str: + """String. Name for this Model.""" + return self._name + + @name.setter + def name(self, new_name: str): + self._name = new_name + + @property + def artifact_location(self) -> str: + """String. 
Location of the model artifacts.""" + return self._artifact_location + + @artifact_location.setter + def artifact_location(self, new_artifact_location: str): + self._artifact_location = new_artifact_location + + @property + def creation_timestamp(self) -> int: + """Integer. Model creation timestamp (milliseconds since the Unix epoch).""" + return self._creation_time + + @property + def last_updated_timestamp(self) -> int: + """Integer. Timestamp of last update for this Model (milliseconds since the Unix + epoch). + """ + return self._last_updated_timestamp + + @last_updated_timestamp.setter + def last_updated_timestamp(self, updated_timestamp: int): + self._last_updated_timestamp = updated_timestamp + + @property + def model_type(self) -> Optional[str]: + """String. Type of the model.""" + return self._model_type + + @model_type.setter + def model_type(self, new_model_type: Optional[str]): + self._model_type = new_model_type + + @property + def run_id(self) -> Optional[str]: + """String. MLflow run ID that generated this model.""" + return self._run_id + + @property + def status(self) -> ModelStatus: + """String. Current status of this Model.""" + return self._status + + @status.setter + def status(self, updated_status: str): + self._status = updated_status + + @property + def status_message(self) -> Optional[str]: + """String. 
Descriptive message for error status conditions.""" + return self._status_message + + @property + def tags(self) -> Dict[str, str]: + """Dictionary of tag key (string) -> tag value for this Model.""" + return self._tags + + @property + def params(self) -> Dict[str, str]: + """Model parameters.""" + return self._params + + @property + def metrics(self) -> Optional[List[Metric]]: + """List of metrics associated with this Model.""" + return self._metrics + + @metrics.setter + def metrics(self, new_metrics: Optional[List[Metric]]): + self._metrics = new_metrics + + @classmethod + def _properties(cls) -> List[str]: + # aggregate with base class properties since cls.__dict__ does not do it automatically + return sorted(cls._get_properties_helper()) + + def _add_tag(self, tag): + self._tags[tag.key] = tag.value + + def to_dictionary(self) -> Dict[str, Any]: + model_dict = dict(self) + model_dict["status"] = str(self.status) + return model_dict diff --git a/mlflow/entities/metric.py b/mlflow/entities/metric.py index bea6926b95d1f..68bd68451a150 100644 --- a/mlflow/entities/metric.py +++ b/mlflow/entities/metric.py @@ -1,3 +1,5 @@ +from typing import Optional + from mlflow.entities._mlflow_object import _MlflowObject from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -10,11 +12,31 @@ class Metric(_MlflowObject): Metric object. 
""" - def __init__(self, key, value, timestamp, step): + def __init__( + self, + key, + value, + timestamp, + step, + model_id: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, + run_id: Optional[str] = None, + ): + if (dataset_name, dataset_digest).count(None) == 1: + raise MlflowException( + "Both dataset_name and dataset_digest must be provided if one is provided", + INVALID_PARAMETER_VALUE, + ) + self._key = key self._value = value self._timestamp = timestamp self._step = step + self._model_id = model_id + self._dataset_name = dataset_name + self._dataset_digest = dataset_digest + self._run_id = run_id @property def key(self): @@ -36,16 +58,38 @@ def step(self): """Integer metric step (x-coordinate).""" return self._step + @property + def model_id(self): + """ID of the Model associated with the metric.""" + return self._model_id + + @property + def dataset_name(self) -> Optional[str]: + """String. Name of the dataset associated with the metric.""" + return self._dataset_name + + @property + def dataset_digest(self) -> Optional[str]: + """String. Digest of the dataset associated with the metric.""" + return self._dataset_digest + + @property + def run_id(self) -> Optional[str]: + """String. 
Run ID associated with the metric.""" + return self._run_id + def to_proto(self): metric = ProtoMetric() metric.key = self.key metric.value = self.value metric.timestamp = self.timestamp metric.step = self.step + # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto return metric @classmethod def from_proto(cls, proto): + # TODO: Add model_id, dataset_name, dataset_digest, and run_id to the proto return cls(proto.key, proto.value, proto.timestamp, proto.step) def __eq__(self, __o): @@ -69,6 +113,10 @@ def to_dictionary(self): "value": self.value, "timestamp": self.timestamp, "step": self.step, + "model_id": self.model_id, + "dataset_name": self.dataset_name, + "dataset_digest": self.dataset_digest, + "run_id": self._run_id, } @classmethod diff --git a/mlflow/entities/model_input.py b/mlflow/entities/model_input.py new file mode 100644 index 0000000000000..456c6db70ae5f --- /dev/null +++ b/mlflow/entities/model_input.py @@ -0,0 +1,18 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelInput(_MlflowObject): + """ModelInput object associated with a Run.""" + + def __init__(self, model_id: str): + self._model_id = model_id + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return False + + @property + def model_id(self) -> str: + """Model ID.""" + return self._model_id diff --git a/mlflow/entities/model_output.py b/mlflow/entities/model_output.py new file mode 100644 index 0000000000000..058a50b316271 --- /dev/null +++ b/mlflow/entities/model_output.py @@ -0,0 +1,24 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelOutput(_MlflowObject): + """ModelOutput object associated with a Run.""" + + def __init__(self, model_id: str, step: int) -> None: + self._model_id = model_id + self._step = step + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return 
False + + @property + def model_id(self) -> str: + """Model ID""" + return self._model_id + + @property + def step(self) -> int: + """Step at which the model was logged""" + return self._step diff --git a/mlflow/entities/model_param.py b/mlflow/entities/model_param.py new file mode 100644 index 0000000000000..b5e4cf7fe8c65 --- /dev/null +++ b/mlflow/entities/model_param.py @@ -0,0 +1,38 @@ +import sys + +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelParam(_MlflowObject): + """ + MLflow entity representing a parameter of a Model. + """ + + def __init__(self, key, value): + if "pyspark.ml" in sys.modules: + import pyspark.ml.param + + if isinstance(key, pyspark.ml.param.Param): + key = key.name + value = str(value) + self._key = key + self._value = value + + @property + def key(self): + """String key corresponding to the parameter name.""" + return self._key + + @property + def value(self): + """String value of the parameter.""" + return self._value + + def __eq__(self, __o): + if isinstance(__o, self.__class__): + return self._key == __o._key + + return False + + def __hash__(self): + return hash(self._key) diff --git a/mlflow/entities/model_registry/model_version.py b/mlflow/entities/model_registry/model_version.py index a30d022c9eb0f..e47f9c8deb1de 100644 --- a/mlflow/entities/model_registry/model_version.py +++ b/mlflow/entities/model_registry/model_version.py @@ -1,3 +1,7 @@ +from typing import Dict, List, Optional + +from mlflow.entities.metric import Metric +from mlflow.entities.model_param import ModelParam +from mlflow.entities.model_registry._model_registry_entity import _ModelRegistryEntity from mlflow.entities.model_registry.model_version_status import ModelVersionStatus from mlflow.entities.model_registry.model_version_tag import ModelVersionTag @@ -12,140 +16,163 @@ class ModelVersion(_ModelRegistryEntity): def __init__( self, - name, - version, - creation_timestamp, - last_updated_timestamp=None, - description=None, - 
user_id=None, - current_stage=None, - source=None, - run_id=None, - status=ModelVersionStatus.to_string(ModelVersionStatus.READY), - status_message=None, - tags=None, - run_link=None, - aliases=None, + name: str, + version: str, + creation_timestamp: int, + last_updated_timestamp: Optional[int] = None, + description: Optional[str] = None, + user_id: Optional[str] = None, + current_stage: Optional[str] = None, + source: Optional[str] = None, + run_id: Optional[str] = None, + status: str = ModelVersionStatus.to_string(ModelVersionStatus.READY), + status_message: Optional[str] = None, + tags: Optional[List[ModelVersionTag]] = None, + run_link: Optional[str] = None, + aliases: Optional[List[str]] = None, + # TODO: Make model_id a required field + # (currently optional to minimize breakages during prototype development) + model_id: Optional[str] = None, + params: Optional[List[ModelParam]] = None, + metrics: Optional[List[Metric]] = None, ): super().__init__() - self._name = name - self._version = version - self._creation_time = creation_timestamp - self._last_updated_timestamp = last_updated_timestamp - self._description = description - self._user_id = user_id - self._current_stage = current_stage - self._source = source - self._run_id = run_id - self._run_link = run_link - self._status = status - self._status_message = status_message - self._tags = {tag.key: tag.value for tag in (tags or [])} - self._aliases = aliases or [] - - @property - def name(self): + self._name: str = name + self._version: str = version + self._creation_time: int = creation_timestamp + self._last_updated_timestamp: Optional[int] = last_updated_timestamp + self._description: Optional[str] = description + self._user_id: Optional[str] = user_id + self._current_stage: Optional[str] = current_stage + self._source: Optional[str] = source + self._run_id: Optional[str] = run_id + self._run_link: Optional[str] = run_link + self._status: str = status + self._status_message: Optional[str] = status_message 
+ self._tags: Dict[str, str] = {tag.key: tag.value for tag in (tags or [])} + self._aliases: List[str] = aliases or [] + self._model_id: Optional[str] = model_id + self._params: Optional[List[ModelParam]] = params + self._metrics: Optional[List[Metric]] = metrics + + @property + def name(self) -> str: """String. Unique name within Model Registry.""" return self._name @name.setter - def name(self, new_name): + def name(self, new_name: str): self._name = new_name @property - def version(self): - """version""" + def version(self) -> str: + """Version""" return self._version @property - def creation_timestamp(self): + def creation_timestamp(self) -> int: """Integer. Model version creation timestamp (milliseconds since the Unix epoch).""" return self._creation_time @property - def last_updated_timestamp(self): + def last_updated_timestamp(self) -> Optional[int]: """Integer. Timestamp of last update for this model version (milliseconds since the Unix epoch). """ return self._last_updated_timestamp @last_updated_timestamp.setter - def last_updated_timestamp(self, updated_timestamp): + def last_updated_timestamp(self, updated_timestamp: int): self._last_updated_timestamp = updated_timestamp @property - def description(self): + def description(self) -> Optional[str]: """String. Description""" return self._description @description.setter - def description(self, description): + def description(self, description: str): self._description = description @property - def user_id(self): + def user_id(self) -> Optional[str]: """String. User ID that created this model version.""" return self._user_id @property - def current_stage(self): + def current_stage(self) -> Optional[str]: """String. Current stage of this model version.""" return self._current_stage @current_stage.setter - def current_stage(self, stage): + def current_stage(self, stage: str): self._current_stage = stage @property - def source(self): + def source(self) -> Optional[str]: """String. 
Source path for the model.""" return self._source @property - def run_id(self): + def run_id(self) -> Optional[str]: """String. MLflow run ID that generated this model.""" return self._run_id @property - def run_link(self): + def run_link(self) -> Optional[str]: """String. MLflow run link referring to the exact run that generated this model version.""" return self._run_link @property - def status(self): + def status(self) -> str: """String. Current Model Registry status for this model.""" return self._status @property - def status_message(self): + def status_message(self) -> Optional[str]: """String. Descriptive message for error status conditions.""" return self._status_message @property - def tags(self): + def tags(self) -> Dict[str, str]: """Dictionary of tag key (string) -> tag value for the current model version.""" return self._tags @property - def aliases(self): + def aliases(self) -> List[str]: """List of aliases (string) for the current model version.""" return self._aliases @aliases.setter - def aliases(self, aliases): + def aliases(self, aliases: List[str]): self._aliases = aliases + @property + def model_id(self) -> Optional[str]: + """String. 
ID of the model associated with this version.""" + return self._model_id + + @property + def params(self) -> Optional[List[ModelParam]]: + """List of parameters associated with this model version.""" + return self._params + + @property + def metrics(self) -> Optional[List[Metric]]: + """List of metrics associated with this model version.""" + return self._metrics + @classmethod - def _properties(cls): + def _properties(cls) -> List[str]: # aggregate with base class properties since cls.__dict__ does not do it automatically return sorted(cls._get_properties_helper()) - def _add_tag(self, tag): + def _add_tag(self, tag: ModelVersionTag): self._tags[tag.key] = tag.value # proto mappers @classmethod - def from_proto(cls, proto): + def from_proto(cls, proto: ProtoModelVersion) -> "ModelVersion": # input: mlflow.protos.model_registry_pb2.ModelVersion # returns: ModelVersion entity model_version = cls( @@ -165,9 +192,10 @@ def from_proto(cls, proto): ) for tag in proto.tags: model_version._add_tag(ModelVersionTag.from_proto(tag)) + # TODO: Include params, metrics, and model ID in proto return model_version - def to_proto(self): + def to_proto(self) -> ProtoModelVersion: # input: ModelVersion entity # returns mlflow.protos.model_registry_pb2.ModelVersion model_version = ProtoModelVersion() @@ -196,4 +224,5 @@ def to_proto(self): [ProtoModelVersionTag(key=key, value=value) for key, value in self._tags.items()] ) model_version.aliases.extend(self.aliases) + # TODO: Include params, metrics, and model ID in proto return model_version diff --git a/mlflow/entities/model_status.py b/mlflow/entities/model_status.py new file mode 100644 index 0000000000000..495eb0638022a --- /dev/null +++ b/mlflow/entities/model_status.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class ModelStatus(str, Enum): + """Enum for status of an :py:class:`mlflow.entities.Model`.""" + + PENDING = "PENDING" + READY = "READY" + FAILED = "FAILED" diff --git a/mlflow/entities/model_tag.py 
b/mlflow/entities/model_tag.py new file mode 100644 index 0000000000000..0774ce27759b1 --- /dev/null +++ b/mlflow/entities/model_tag.py @@ -0,0 +1,25 @@ +from mlflow.entities._mlflow_object import _MlflowObject + + +class ModelTag(_MlflowObject): + """Tag object associated with a Model.""" + + def __init__(self, key, value): + self._key = key + self._value = value + + def __eq__(self, other): + if type(other) is type(self): + # TODO deep equality here? + return self.__dict__ == other.__dict__ + return False + + @property + def key(self): + """String name of the tag.""" + return self._key + + @property + def value(self): + """String value of the tag.""" + return self._value diff --git a/mlflow/entities/run.py b/mlflow/entities/run.py index 0fade6fa8daa2..8bdc68fe8d895 100644 --- a/mlflow/entities/run.py +++ b/mlflow/entities/run.py @@ -4,6 +4,7 @@ from mlflow.entities.run_data import RunData from mlflow.entities.run_info import RunInfo from mlflow.entities.run_inputs import RunInputs +from mlflow.entities.run_outputs import RunOutputs from mlflow.exceptions import MlflowException from mlflow.protos.service_pb2 import Run as ProtoRun @@ -14,13 +15,18 @@ class Run(_MlflowObject): """ def __init__( - self, run_info: RunInfo, run_data: RunData, run_inputs: Optional[RunInputs] = None + self, + run_info: RunInfo, + run_data: RunData, + run_inputs: Optional[RunInputs] = None, + run_outputs: Optional[RunOutputs] = None, ) -> None: if run_info is None: raise MlflowException("run_info cannot be None") self._info = run_info self._data = run_data self._inputs = run_inputs + self._outputs = run_outputs @property def info(self) -> RunInfo: @@ -43,12 +49,21 @@ def data(self) -> RunData: @property def inputs(self) -> RunInputs: """ - The run inputs, including dataset inputs + The run inputs, including dataset inputs. :rtype: :py:class:`mlflow.entities.RunInputs` """ return self._inputs + @property + def outputs(self) -> RunOutputs: + """ + The run outputs, including model outputs. 
+ + :rtype: :py:class:`mlflow.entities.RunOutputs` + """ + return self._outputs + def to_proto(self): run = ProtoRun() run.info.MergeFrom(self.info.to_proto()) @@ -56,6 +71,9 @@ def to_proto(self): run.data.MergeFrom(self.data.to_proto()) if self.inputs: run.inputs.MergeFrom(self.inputs.to_proto()) + # TODO: Support proto conversion for RunOutputs + # if self.outputs: + # run.outputs.MergeFrom(self.outputs.to_proto()) return run @classmethod @@ -63,7 +81,9 @@ def from_proto(cls, proto): return cls( RunInfo.from_proto(proto.info), RunData.from_proto(proto.data), - RunInputs.from_proto(proto.inputs), + RunInputs.from_proto(proto.inputs) if proto.inputs else None, + # TODO: Support proto conversion for RunOutputs + # RunOutputs.from_proto(proto.outputs) if proto.outputs else None, ) def to_dictionary(self) -> Dict[Any, Any]: @@ -74,4 +94,6 @@ def to_dictionary(self) -> Dict[Any, Any]: run_dict["data"] = self.data.to_dictionary() if self.inputs: run_dict["inputs"] = self.inputs.to_dictionary() + if self.outputs: + run_dict["outputs"] = self.outputs.to_dictionary() return run_dict diff --git a/mlflow/entities/run_inputs.py b/mlflow/entities/run_inputs.py index d28f026c71bc3..e5b8f1ccc7c7f 100644 --- a/mlflow/entities/run_inputs.py +++ b/mlflow/entities/run_inputs.py @@ -2,14 +2,16 @@ from mlflow.entities._mlflow_object import _MlflowObject from mlflow.entities.dataset_input import DatasetInput +from mlflow.entities.model_input import ModelInput from mlflow.protos.service_pb2 import RunInputs as ProtoRunInputs class RunInputs(_MlflowObject): """RunInputs object.""" - def __init__(self, dataset_inputs: List[DatasetInput]) -> None: + def __init__(self, dataset_inputs: List[DatasetInput], model_inputs: List[ModelInput]) -> None: self._dataset_inputs = dataset_inputs + self._model_inputs = model_inputs def __eq__(self, other: _MlflowObject) -> bool: if type(other) is type(self): @@ -21,16 +23,26 @@ def dataset_inputs(self) -> List[DatasetInput]: """Array of dataset 
inputs.""" return self._dataset_inputs + @property + def model_inputs(self) -> List[ModelInput]: + """Array of model inputs.""" + return self._model_inputs + def to_proto(self): run_inputs = ProtoRunInputs() run_inputs.dataset_inputs.extend( [dataset_input.to_proto() for dataset_input in self.dataset_inputs] ) + # TODO: Support proto conversion for model inputs + # run_inputs.model_inputs.extend( + # [model_input.to_proto() for model_input in self.model_inputs] + # ) return run_inputs def to_dictionary(self) -> Dict[Any, Any]: return { "dataset_inputs": self.dataset_inputs, + "model_inputs": self.model_inputs, } @classmethod @@ -38,4 +50,8 @@ def from_proto(cls, proto): dataset_inputs = [ DatasetInput.from_proto(dataset_input) for dataset_input in proto.dataset_inputs ] - return cls(dataset_inputs) + # TODO: Support proto conversion for model inputs + # model_inputs = [ + # ModelInput.from_proto(model_input) for model_input in proto.model_inputs + # ] + return cls(dataset_inputs, []) diff --git a/mlflow/entities/run_outputs.py b/mlflow/entities/run_outputs.py new file mode 100644 index 0000000000000..3d4b8a5f83b77 --- /dev/null +++ b/mlflow/entities/run_outputs.py @@ -0,0 +1,26 @@ +from typing import Any, Dict, List + +from mlflow.entities._mlflow_object import _MlflowObject +from mlflow.entities.model_output import ModelOutput + + +class RunOutputs(_MlflowObject): + """RunOutputs object.""" + + def __init__(self, model_outputs: List[ModelOutput]) -> None: + self._model_outputs = model_outputs + + def __eq__(self, other: _MlflowObject) -> bool: + if type(other) is type(self): + return self.__dict__ == other.__dict__ + return False + + @property + def model_outputs(self) -> List[ModelOutput]: + """Array of model outputs.""" + return self._model_outputs + + def to_dictionary(self) -> Dict[Any, Any]: + return { + "model_outputs": self.model_outputs, + } diff --git a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java 
b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java index 487c8f9a2864e..4139fe8cab66d 100644 --- a/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java +++ b/mlflow/java/client/src/main/java/org/mlflow/internal/proto/Internal.java @@ -32,6 +32,10 @@ public enum InputVertexType * DATASET = 2; */ DATASET(2), + /** + * MODEL = 3; + */ + MODEL(3), ; /** @@ -42,6 +46,10 @@ public enum InputVertexType * DATASET = 2; */ public static final int DATASET_VALUE = 2; + /** + * MODEL = 3; + */ + public static final int MODEL_VALUE = 3; public final int getNumber() { @@ -66,6 +74,7 @@ public static InputVertexType forNumber(int value) { switch (value) { case 1: return RUN; case 2: return DATASET; + case 3: return MODEL; default: return null; } } @@ -115,6 +124,107 @@ private InputVertexType(int value) { // @@protoc_insertion_point(enum_scope:mlflow.internal.InputVertexType) } + /** + *
+   * Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can
+   * have an output relationship.
+   * 
+ * + * Protobuf enum {@code mlflow.internal.OutputVertexType} + */ + public enum OutputVertexType + implements com.google.protobuf.ProtocolMessageEnum { + /** + * RUN_OUTPUT = 1; + */ + RUN_OUTPUT(1), + /** + * MODEL_OUTPUT = 2; + */ + MODEL_OUTPUT(2), + ; + + /** + * RUN_OUTPUT = 1; + */ + public static final int RUN_OUTPUT_VALUE = 1; + /** + * MODEL_OUTPUT = 2; + */ + public static final int MODEL_OUTPUT_VALUE = 2; + + + public final int getNumber() { + return value; + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static OutputVertexType valueOf(int value) { + return forNumber(value); + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + */ + public static OutputVertexType forNumber(int value) { + switch (value) { + case 1: return RUN_OUTPUT; + case 2: return MODEL_OUTPUT; + default: return null; + } + } + + public static com.google.protobuf.Internal.EnumLiteMap + internalGetValueMap() { + return internalValueMap; + } + private static final com.google.protobuf.Internal.EnumLiteMap< + OutputVertexType> internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public OutputVertexType findValueByNumber(int number) { + return OutputVertexType.forNumber(number); + } + }; + + public final com.google.protobuf.Descriptors.EnumValueDescriptor + getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } + public final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptorForType() { + return getDescriptor(); + } + public static final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptor() { + return org.mlflow.internal.proto.Internal.getDescriptor().getEnumTypes().get(1); + } + + private static final OutputVertexType[] VALUES = 
values(); + + public static OutputVertexType valueOf( + com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException( + "EnumValueDescriptor is not for this type."); + } + return VALUES[desc.getIndex()]; + } + + private final int value; + + private OutputVertexType(int value) { + this.value = value; + } + + // @@protoc_insertion_point(enum_scope:mlflow.internal.OutputVertexType) + } + public static com.google.protobuf.Descriptors.FileDescriptor getDescriptor() { @@ -125,9 +235,10 @@ private InputVertexType(int value) { static { java.lang.String[] descriptorData = { "\n\016internal.proto\022\017mlflow.internal\032\025scala" + - "pb/scalapb.proto*\'\n\017InputVertexType\022\007\n\003R" + - "UN\020\001\022\013\n\007DATASET\020\002B#\n\031org.mlflow.internal" + - ".proto\220\001\001\342?\002\020\001" + "pb/scalapb.proto*2\n\017InputVertexType\022\007\n\003R" + + "UN\020\001\022\013\n\007DATASET\020\002\022\t\n\005MODEL\020\003*4\n\020OutputVe" + + "rtexType\022\016\n\nRUN_OUTPUT\020\001\022\020\n\014MODEL_OUTPUT" + + "\020\002B#\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index 688bba77b2301..23045b1ce97b1 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -406,7 +406,7 @@ def load_retriever(persist_directory): @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( lc_model, - artifact_path, + name: Optional[str] = None, conda_env=None, code_paths=None, registered_model_name=None, @@ -422,6 +422,11 @@ def log_model( run_id=None, model_config=None, streamable=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, + model_id: 
Optional[str] = None, ): """ Log a LangChain model as an MLflow artifact for the current run. @@ -437,7 +442,7 @@ def log_model( .. Note:: Experimental: Using model as path may change or be removed in a future release without warning. - artifact_path: Run-relative artifact path. + name: The name of the model. conda_env: {{ conda_env }} code_paths: {{ code_paths }} registered_model_name: This argument may change or be removed in a @@ -547,7 +552,7 @@ def load_retriever(persist_directory): metadata of the logged model. """ return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.langchain, registered_model_name=registered_model_name, lc_model=lc_model, @@ -565,6 +570,11 @@ def load_retriever(persist_directory): run_id=run_id, model_config=model_config, streamable=streamable, + params=params, + tags=tags, + model_type=model_type, + step=step, + model_id=model_id, ) diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 170e89f5f2a57..2f9d807892ae5 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -14,6 +14,7 @@ import mlflow from mlflow.artifacts import download_artifacts +from mlflow.entities import Metric, ModelOutput, ModelStatus from mlflow.exceptions import MlflowException from mlflow.models.resources import Resource, ResourceType, _ResourceBuilder from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST @@ -654,13 +655,18 @@ def from_dict(cls, model_dict): @classmethod def log( cls, - artifact_path, + name, flavor, registered_model_name=None, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, metadata=None, run_id=None, resources=None, + model_type: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + step: int = 0, + model_id: Optional[str] = None, **kwargs, ): """ @@ -668,7 +674,7 @@ def log( active run. Args: - artifact_path: Run relative path identifying the model. + name: The name of the model. 
flavor: Flavor module to save the model with. The module must have the ``save_model`` function that will persist the model as a valid MLflow model. @@ -689,15 +695,79 @@ def log( A :py:class:`ModelInfo ` instance that contains the metadata of the logged model. """ - from mlflow.utils.model_utils import _validate_and_get_model_config_from_file + if (model_id, name).count(None) == 2: + raise MlflowException( + "Either `model_id` or `name` must be specified when logging a model. " + "Both are None.", + error_code=INVALID_PARAMETER_VALUE, + ) + + def log_model_metrics_for_step(client, model_id, run_id, step): + metric_names = client.get_run(run_id).data.metrics.keys() + metrics_for_step = [] + for metric_name in metric_names: + history = client.get_metric_history(run_id, metric_name) + metrics_for_step.extend( + [ + Metric( + key=metric.key, + value=metric.value, + timestamp=metric.timestamp, + step=metric.step, + dataset_name=metric.dataset_name, + dataset_digest=metric.dataset_digest, + run_id=metric.run_id, + model_id=model_id, + ) + for metric in history + if metric.step == step and metric.model_id is None + ] + ) + client.log_batch(run_id=run_id, metrics=metrics_for_step) registered_model = None with TempDir() as tmp: local_path = tmp.path("model") - if run_id is None: - run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id + + tracking_uri = _resolve_tracking_uri() + client = mlflow.MlflowClient(tracking_uri) + active_run = mlflow.tracking.fluent.active_run() + if model_id is not None: + model = client.get_logged_model(model_id) + else: + params = { + **(params or {}), + **( + client.get_run(active_run.info.run_id).data.params + if active_run is not None + else {} + ), + } + model = client.create_logged_model( + experiment_id=mlflow.tracking.fluent._get_experiment_id(), + # TODO: Update model name + name=name, + run_id=active_run.info.run_id if active_run is not None else None, + model_type=model_type, + params={key: str(value) for key, value in 
params.items()}, + tags={key: str(value) for key, value in tags.items()} + if tags is not None + else None, + ) + + if active_run is not None: + run_id = active_run.info.run_id + client.log_outputs(run_id=run_id, models=[ModelOutput(model.model_id, step=step)]) + log_model_metrics_for_step( + client=client, model_id=model.model_id, run_id=run_id, step=step + ) + mlflow_model = cls( - artifact_path=artifact_path, run_id=run_id, metadata=metadata, resources=resources + artifact_path=model.artifact_location, + model_uuid=model.model_id, + run_id=active_run.info.run_id if active_run is not None else None, + metadata=metadata, + resources=resources, ) flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs) # `save_model` calls `load_model` to infer the model requirements, which may result in @@ -708,7 +778,6 @@ def log( if is_in_databricks_runtime(): _copy_model_metadata_for_uc_sharing(local_path, flavor) - tracking_uri = _resolve_tracking_uri() serving_input = mlflow_model.get_serving_input(local_path) # We check signature presence here as some flavors have a default signature as a # fallback when not provided by user, which is set during flavor's save_model() call. 
@@ -717,66 +786,60 @@ def log( _logger.warning(_LOG_MODEL_MISSING_INPUT_EXAMPLE_WARNING) elif tracking_uri == "databricks" or get_uri_scheme(tracking_uri) == "databricks": _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) - mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id) - - # if the model_config kwarg is passed in, then log the model config as an params - if model_config := kwargs.get("model_config"): - if isinstance(model_config, str): - try: - file_extension = os.path.splitext(model_config)[1].lower() - if file_extension == ".json": - with open(model_config) as f: - model_config = json.load(f) - elif file_extension in [".yaml", ".yml"]: - model_config = _validate_and_get_model_config_from_file(model_config) - else: - _logger.warning( - "Unsupported file format for model config: %s. " - "Failed to load model config.", - model_config, - ) - except Exception as e: - _logger.warning("Failed to load model config from %s: %s", model_config, e) - try: - from mlflow.models.utils import _flatten_nested_params - - # We are using the `/` separator to flatten the nested params - # since we are using the same separator to log nested metrics. - params_to_log = _flatten_nested_params(model_config, sep="/") - except Exception as e: - _logger.warning("Failed to flatten nested params: %s", str(e)) - params_to_log = model_config - - try: - mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id) - except Exception as e: - _logger.warning("Failed to log model config as params: %s", str(e)) - - try: - mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id) - except MlflowException: - # We need to swallow all mlflow exceptions to maintain backwards compatibility with - # older tracking servers. Only print out a warning for now. 
- _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri()) - _logger.debug("", exc_info=True) - - if registered_model_name is not None: - registered_model = mlflow.tracking._model_registry.fluent._register_model( - f"runs:/{run_id}/{mlflow_model.artifact_path}", - registered_model_name, - await_registration_for=await_registration_for, - local_model_path=local_path, - ) - model_info = mlflow_model.get_model_info() - if registered_model is not None: - model_info.registered_model_version = registered_model.version + client.log_model_artifacts(model.model_id, local_path) + client.finalize_logged_model(model.model_id, status=ModelStatus.READY) + + # # if the model_config kwarg is passed in, then log the model config as an params + # if model_config := kwargs.get("model_config"): + # if isinstance(model_config, str): + # try: + # file_extension = os.path.splitext(model_config)[1].lower() + # if file_extension == ".json": + # with open(model_config) as f: + # model_config = json.load(f) + # elif file_extension in [".yaml", ".yml"]: + # model_config = _validate_and_get_model_config_from_file(model_config) + # else: + # _logger.warning( + # "Unsupported file format for model config: %s. " + # "Failed to load model config.", + # model_config, + # ) + # except Exception as e: + # _logger.warning( + # "Failed to load model config from %s: %s", model_config, e + # ) + # + # try: + # from mlflow.models.utils import _flatten_nested_params + # + # # We are using the `/` separator to flatten the nested params + # # since we are using the same separator to log nested metrics. 
+ # params_to_log = _flatten_nested_params(model_config, sep="/") + # except Exception as e: + # _logger.warning("Failed to flatten nested params: %s", str(e)) + # params_to_log = model_config + # + # try: + # mlflow.tracking.fluent.log_params(params_to_log or {}, run_id=run_id) + # except Exception as e: + # _logger.warning("Failed to log model config as params: %s", str(e)) + # + # try: + # mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id) + # except MlflowException: + # # We need to swallow all mlflow exceptions to maintain backwards compatibility + # # with older tracking servers. Only print out a warning for now. + # _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri()) + # _logger.debug("", exc_info=True) # validate input example works for serving when logging the model if serving_input: from mlflow.models import validate_serving_input try: + model_info = mlflow_model.get_model_info() validate_serving_input(model_info.model_uri, serving_input) except Exception as e: _logger.warning( @@ -792,7 +855,21 @@ def log( exc_info=_logger.isEnabledFor(logging.DEBUG), ) - return model_info + if registered_model_name is not None: + registered_model = mlflow.tracking._model_registry.fluent._register_model( + f"models:/{model.model_id}", + registered_model_name, + await_registration_for=await_registration_for, + local_model_path=local_path, + ) + return client.get_model_version(registered_model_name, registered_model.version) + else: + return client.get_logged_model(model.model_id) + # model_info = mlflow_model.get_model_info() + # if registered_model is not None: + # model_info.registered_model_version = registered_model.version + + # return model_info def _copy_model_metadata_for_uc_sharing(local_path, flavor): diff --git a/mlflow/protos/internal.proto b/mlflow/protos/internal.proto index 614a1916c1415..057e18fea1c9c 100644 --- a/mlflow/protos/internal.proto +++ b/mlflow/protos/internal.proto @@ -20,4 +20,14 @@ enum 
InputVertexType { RUN = 1; DATASET = 2; + + MODEL = 3; +} + +// Types of vertices represented in MLflow Run Outputs. Valid vertices are MLflow objects that can +// have an output relationship. +enum OutputVertexType { + RUN_OUTPUT = 1; + + MODEL_OUTPUT = 2; } diff --git a/mlflow/protos/internal_pb2.py b/mlflow/protos/internal_pb2.py index 7fcf97a79455a..7aa93249e6ff5 100644 --- a/mlflow/protos/internal_pb2.py +++ b/mlflow/protos/internal_pb2.py @@ -19,7 +19,7 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,7 +28,9 @@ _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _globals['_INPUTVERTEXTYPE']._serialized_start=58 - _globals['_INPUTVERTEXTYPE']._serialized_end=97 + _globals['_INPUTVERTEXTYPE']._serialized_end=108 + _globals['_OUTPUTVERTEXTYPE']._serialized_start=110 + _globals['_OUTPUTVERTEXTYPE']._serialized_end=162 # @@protoc_insertion_point(module_scope) else: @@ -50,12 +52,17 @@ from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2 - DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*\'\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0einternal.proto\x12\x0fmlflow.internal\x1a\x15scalapb/scalapb.proto*2\n\x0fInputVertexType\x12\x07\n\x03RUN\x10\x01\x12\x0b\n\x07\x44\x41TASET\x10\x02\x12\t\n\x05MODEL\x10\x03*4\n\x10OutputVertexType\x12\x0e\n\nRUN_OUTPUT\x10\x01\x12\x10\n\x0cMODEL_OUTPUT\x10\x02\x42#\n\x19org.mlflow.internal.proto\x90\x01\x01\xe2?\x02\x10\x01') _INPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['InputVertexType'] InputVertexType = enum_type_wrapper.EnumTypeWrapper(_INPUTVERTEXTYPE) + _OUTPUTVERTEXTYPE = DESCRIPTOR.enum_types_by_name['OutputVertexType'] + OutputVertexType = enum_type_wrapper.EnumTypeWrapper(_OUTPUTVERTEXTYPE) RUN = 1 DATASET = 2 + MODEL = 3 + RUN_OUTPUT = 1 + MODEL_OUTPUT = 2 if _descriptor._USE_C_DESCRIPTORS == False: @@ -63,6 +70,8 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\031org.mlflow.internal.proto\220\001\001\342?\002\020\001' _INPUTVERTEXTYPE._serialized_start=58 - _INPUTVERTEXTYPE._serialized_end=97 + _INPUTVERTEXTYPE._serialized_end=108 + _OUTPUTVERTEXTYPE._serialized_start=110 + _OUTPUTVERTEXTYPE._serialized_end=162 # @@protoc_insertion_point(module_scope) diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 2c79912cf47ce..0f47db32a67ad 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -2633,7 +2633,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) @trace_disabled # Suppress traces for internal predict calls while logging model def log_model( - artifact_path, + name=None, loader_module=None, data_path=None, code_path=None, # deprecated @@ -2653,6 
+2653,11 @@ def log_model( example_no_conversion=None, streamable=None, resources: Optional[Union[str, List[Resource]]] = None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, + model_id: Optional[str] = None, ): """ Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow @@ -2665,7 +2670,7 @@ def log_model( and the parameters for the first workflow: ``python_model``, ``artifacts`` together. Args: - artifact_path: The run-relative artifact path to which to log the Python model. + name: The name of the model. loader_module: The name of the Python module that is used to load the model from ``data_path``. This module must define a method with the prototype ``_load_pyfunc(data_path)``. If not ``None``, this module and its @@ -2852,7 +2857,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: metadata of the logged model. """ return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.pyfunc, loader_module=loader_module, data_path=data_path, @@ -2873,6 +2878,11 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: streamable=streamable, resources=resources, infer_code_paths=infer_code_paths, + params=params, + tags=tags, + model_type=model_type, + step=step, + model_id=model_id, ) diff --git a/mlflow/pytorch/__init__.py b/mlflow/pytorch/__init__.py index 198ac1ff6c609..02c453745d3d5 100644 --- a/mlflow/pytorch/__init__.py +++ b/mlflow/pytorch/__init__.py @@ -137,7 +137,7 @@ def get_default_conda_env(): @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="torch")) def log_model( pytorch_model, - artifact_path, + name: Optional[str] = None, conda_env=None, code_paths=None, pickle_module=None, @@ -150,6 +150,11 @@ def log_model( pip_requirements=None, extra_pip_requirements=None, metadata=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = 
None, + model_type: Optional[str] = None, + step: int = 0, + model_id: Optional[str] = None, **kwargs, ): """ @@ -177,7 +182,7 @@ class definition itself, should be included in one of the following locations: ``conda_env`` parameter. - One or more of the files specified by the ``code_paths`` parameter. - artifact_path: Run-relative artifact path. + name: The name of the model. conda_env: {{ conda_env }} code_paths: {{ code_paths }} pickle_module: The module that PyTorch should use to serialize ("pickle") the specified @@ -293,7 +298,7 @@ class definition itself, should be included in one of the following locations: """ pickle_module = pickle_module or mlflow_pytorch_pickle_module return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.pytorch, pytorch_model=pytorch_model, conda_env=conda_env, @@ -308,6 +313,11 @@ class definition itself, should be included in one of the following locations: pip_requirements=pip_requirements, extra_pip_requirements=extra_pip_requirements, metadata=metadata, + params=params, + tags=tags, + model_type=model_type, + step=step, + model_id=model_id, **kwargs, ) diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py index 39b3cf7f5cc5b..7bc50a5f75b22 100644 --- a/mlflow/sklearn/__init__.py +++ b/mlflow/sklearn/__init__.py @@ -333,7 +333,7 @@ def save_model( @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) def log_model( sk_model, - artifact_path, + name: Optional[str] = None, conda_env=None, code_paths=None, serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, @@ -345,6 +345,11 @@ def log_model( extra_pip_requirements=None, pyfunc_predict_fn="predict", metadata=None, + params: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + model_type: Optional[str] = None, + step: int = 0, + model_id: Optional[str] = None, ): """ Log a scikit-learn model as an MLflow artifact for the current run. 
Produces an MLflow Model @@ -356,7 +361,7 @@ def log_model( Args: sk_model: scikit-learn model to be saved. - artifact_path: Run-relative artifact path. + name: Model name. conda_env: {{ conda_env }} code_paths: {{ code_paths }} serialization_format: The format in which to serialize the model. This should be one of @@ -410,7 +415,7 @@ def log_model( """ return Model.log( - artifact_path=artifact_path, + name=name, flavor=mlflow.sklearn, sk_model=sk_model, conda_env=conda_env, @@ -424,6 +429,11 @@ def log_model( extra_pip_requirements=extra_pip_requirements, pyfunc_predict_fn=pyfunc_predict_fn, metadata=metadata, + params=params, + tags=tags, + model_type=model_type, + step=step, + model_id=model_id, ) diff --git a/mlflow/store/artifact/models_artifact_repo.py b/mlflow/store/artifact/models_artifact_repo.py index f54a062f22716..614743200d147 100644 --- a/mlflow/store/artifact/models_artifact_repo.py +++ b/mlflow/store/artifact/models_artifact_repo.py @@ -90,8 +90,15 @@ def _get_model_uri_infos(uri): get_databricks_profile_uri_from_artifact_uri(uri) or mlflow.get_registry_uri() ) client = MlflowClient(registry_uri=databricks_profile_uri) - name, version = get_model_name_and_version(client, uri) - download_uri = client.get_model_version_download_uri(name, version) + name_and_version_or_id = get_model_name_and_version(client, uri) + if len(name_and_version_or_id) == 1: + name = None + version = None + model_id = name_and_version_or_id[0] + download_uri = client.get_logged_model(model_id).artifact_location + else: + name, version = name_and_version_or_id + download_uri = client.get_model_version_download_uri(name, version) return ( name, diff --git a/mlflow/store/artifact/utils/models.py b/mlflow/store/artifact/utils/models.py index e4347328d9e52..9e3cab029fc6b 100644 --- a/mlflow/store/artifact/utils/models.py +++ b/mlflow/store/artifact/utils/models.py @@ -37,7 +37,8 @@ def _get_latest_model_version(client, name, stage): class ParsedModelUri(NamedTuple): - name: str + 
name: str
+    name: Optional[str] = None
     version: Optional[str] = None
     stage: Optional[str] = None
     alias: Optional[str] = None
+    model_id: Optional[str] = None
@@ -47,6 +48,7 @@ def _parse_model_uri(uri):
     """
     Returns a ParsedModelUri tuple. Since a models:/ URI can only have one of
     {version, stage, 'latest', alias}, it will return
+    - (None, None, None, None, id) to look for a specific model by ID,
     - (name, version, None, None) to look for a specific version,
     - (name, None, stage, None) to look for the latest version of a stage,
     - (name, None, None, None) to look for the latest of all versions.
@@ -77,16 +79,21 @@ def _parse_model_uri(uri):
         else:
             # The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
             return ParsedModelUri(name, stage=suffix)
-    else:
+    elif "@" in path:
         # The URI is an alias URI, e.g. "models:/AdsModel1@Champion"
         alias_parts = parts[0].rsplit("@", 1)
         if len(alias_parts) != 2 or alias_parts[1].strip() == "":
             raise MlflowException(_improper_model_uri_msg(uri))
         return ParsedModelUri(alias_parts[0], alias=alias_parts[1])
+    else:
+        # The URI is of the form "models:/<model_id>". `model_id` is the LAST NamedTuple
+        # field (with a default) so that the pre-existing positional constructions above
+        # (e.g. `ParsedModelUri(name, stage=suffix)`) keep assigning to `name`; it must
+        # therefore be passed by keyword here.
+        return ParsedModelUri(model_id=parts[0])
 
 
 def get_model_name_and_version(client, models_uri):
-    (model_name, model_version, model_stage, model_alias) = _parse_model_uri(models_uri)
+    (model_name, model_version, model_stage, model_alias, model_id) = _parse_model_uri(models_uri)
+    if model_id is not None:
+        return (model_id,)
     if model_version is not None:
         return model_name, model_version
     if model_alias is not None:
diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py
index da5869fe5f825..13852e451a59c 100644
--- a/mlflow/store/model_registry/file_store.py
+++ b/mlflow/store/model_registry/file_store.py
@@ -5,7 +5,7 @@
 import time
 import urllib
 from os.path import join
-from typing import List
+from typing import List, Optional
 
 from mlflow.entities.model_registry import (
     ModelVersion,
@@ -570,9 +570,23 @@ def _get_model_version_aliases(self,
directory): return [alias.alias for alias in aliases if alias.version == version] def _get_file_model_version_from_dir(self, directory) -> FileModelVersion: + from mlflow.tracking.client import MlflowClient + meta = FileStore._read_yaml(directory, FileStore.META_DATA_FILE_NAME) meta["tags"] = self._get_model_version_tags_from_dir(directory) meta["aliases"] = self._get_model_version_aliases(directory) + # Fetch metrics and params from model ID + # + # TODO: Propagate tracking URI to file store directly, rather than relying on global + # URI (individual MlflowClient instances may have different tracking URIs) + if "model_id" in meta: + try: + model = MlflowClient().get_logged_model(meta["model_id"]) + meta["metrics"] = model.metrics + meta["params"] = model.params + except Exception: + # TODO: Make this exception handling more specific + pass return FileModelVersion.from_dictionary(meta) def _save_model_version_as_meta_file( @@ -605,6 +619,7 @@ def create_model_version( run_link=None, description=None, local_model_path=None, + model_id: Optional[str] = None, ) -> ModelVersion: """ Create a new model version from given source and run ID. @@ -617,6 +632,8 @@ def create_model_version( instances associated with this model version. run_link: Link to the run from an MLflow tracking server that generated this model. description: Description of the version. + model_id: The ID of the model (from an Experiment) that is being promoted to a + registered model version, if applicable. 
Returns: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion` @@ -639,9 +656,19 @@ def next_version(registered_model_name): if urllib.parse.urlparse(source).scheme == "models": parsed_model_uri = _parse_model_uri(source) try: - storage_location = self.get_model_version_download_uri( - parsed_model_uri.name, parsed_model_uri.version - ) + from mlflow.tracking.client import MlflowClient + + if parsed_model_uri.model_id is not None: + # TODO: Propagate tracking URI to file store directly, rather than relying on + # global URI (individual MlflowClient instances may have different tracking + # URIs) + model = MlflowClient().get_logged_model(parsed_model_uri.model_id) + storage_location = model.artifact_location + run_id = run_id or model.run_id + else: + storage_location = self.get_model_version_download_uri( + parsed_model_uri.name, parsed_model_uri.version + ) except Exception as e: raise MlflowException( f"Unable to fetch model from model URI source artifact location '{source}'." 
@@ -667,6 +694,7 @@ def next_version(registered_model_name): tags=tags, aliases=[], storage_location=storage_location, + model_id=model_id, ) model_version_dir = self._get_model_version_dir(name, version) mkdir(model_version_dir) @@ -677,7 +705,7 @@ def next_version(registered_model_name): if tags is not None: for tag in tags: self.set_model_version_tag(name, version, tag) - return model_version.to_mlflow_entity() + return self.get_model_version(name, version) except Exception as e: more_retries = self.CREATE_MODEL_VERSION_RETRIES - attempt - 1 logging.warning( diff --git a/mlflow/store/tracking/file_store.py b/mlflow/store/tracking/file_store.py index c2bdd63a09d1a..1a4cf44ac1f8c 100644 --- a/mlflow/store/tracking/file_store.py +++ b/mlflow/store/tracking/file_store.py @@ -5,8 +5,9 @@ import sys import time import uuid +from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple from mlflow.entities import ( Dataset, @@ -14,12 +15,19 @@ Experiment, ExperimentTag, InputTag, + LoggedModel, Metric, + ModelInput, + ModelOutput, + ModelParam, + ModelStatus, + ModelTag, Param, Run, RunData, RunInfo, RunInputs, + RunOutputs, RunStatus, RunTag, SourceType, @@ -38,7 +46,7 @@ INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST, ) -from mlflow.protos.internal_pb2 import InputVertexType +from mlflow.protos.internal_pb2 import InputVertexType, OutputVertexType from mlflow.store.entities.paged_list import PagedList from mlflow.store.model_registry.file_store import FileStore as ModelRegistryFileStore from mlflow.store.tracking import ( @@ -159,6 +167,7 @@ class FileStore(AbstractStore): EXPERIMENT_TAGS_FOLDER_NAME = "tags" DATASETS_FOLDER_NAME = "datasets" INPUTS_FOLDER_NAME = "inputs" + OUTPUTS_FOLDER_NAME = "outputs" META_DATA_FILE_NAME = "meta.yaml" DEFAULT_EXPERIMENT_ID = "0" TRACE_INFO_FILE_NAME = "trace_info.yaml" @@ -170,6 +179,7 @@ class 
FileStore(AbstractStore): DATASETS_FOLDER_NAME, TRACES_FOLDER_NAME, ] + MODELS_FOLDER_NAME = "models" def __init__(self, root_directory=None, artifact_root_uri=None): """ @@ -237,6 +247,12 @@ def _get_metric_path(self, experiment_id, run_uuid, metric_key): self._get_run_dir(experiment_id, run_uuid), FileStore.METRICS_FOLDER_NAME, metric_key ) + def _get_model_metric_path(self, experiment_id: str, model_id: str, metric_key: str) -> str: + _validate_metric_name(metric_key) + return os.path.join( + self._get_model_dir(experiment_id, model_id), FileStore.METRICS_FOLDER_NAME, metric_key + ) + def _get_param_path(self, experiment_id, run_uuid, param_name): _validate_run_id(run_uuid) _validate_param_name(param_name) @@ -687,11 +703,12 @@ def _get_run_from_info(self, run_info): params = self._get_all_params(run_info) tags = self._get_all_tags(run_info) inputs: RunInputs = self._get_all_inputs(run_info) + outputs: RunOutputs = self._get_all_outputs(run_info) if not run_info.run_name: run_name = _get_run_name_from_tags(tags) if run_name: run_info._set_run_name(run_name) - return Run(run_info, RunData(metrics, params, tags), inputs) + return Run(run_info, RunData(metrics, params, tags), inputs, outputs) def _get_run_info(self, run_uuid): """ @@ -751,10 +768,12 @@ def _get_resource_files(self, root_dir, subfolder_name): return source_dirs[0], file_names @staticmethod - def _get_metric_from_file(parent_path, metric_name, exp_id): + def _get_metric_from_file( + parent_path: str, metric_name: str, run_id: str, exp_id: str + ) -> Metric: _validate_metric_name(metric_name) metric_objs = [ - FileStore._get_metric_from_line(metric_name, line, exp_id) + FileStore._get_metric_from_line(run_id, metric_name, line, exp_id) for line in read_file_lines(parent_path, metric_name) ] if len(metric_objs) == 0: @@ -775,24 +794,38 @@ def _get_all_metrics(self, run_info): metrics = [] for metric_file in metric_files: metrics.append( - self._get_metric_from_file(parent_path, metric_file, 
run_info.experiment_id) + self._get_metric_from_file( + parent_path, metric_file, run_info.run_id, run_info.experiment_id + ) ) return metrics @staticmethod - def _get_metric_from_line(metric_name, metric_line, exp_id): + def _get_metric_from_line( + run_id: str, metric_name: str, metric_line: str, exp_id: str + ) -> Metric: metric_parts = metric_line.strip().split(" ") - if len(metric_parts) != 2 and len(metric_parts) != 3: + if len(metric_parts) != 2 and len(metric_parts) != 3 and len(metric_parts) != 5: raise MlflowException( f"Metric '{metric_name}' is malformed; persisted metric data contained " - f"{len(metric_parts)} fields. Expected 2 or 3 fields. " + f"{len(metric_parts)} fields. Expected 2, 3, or 5 fields. " f"Experiment id: {exp_id}", databricks_pb2.INTERNAL_ERROR, ) ts = int(metric_parts[0]) val = float(metric_parts[1]) step = int(metric_parts[2]) if len(metric_parts) == 3 else 0 - return Metric(key=metric_name, value=val, timestamp=ts, step=step) + dataset_name = str(metric_parts[3]) if len(metric_parts) == 5 else None + dataset_digest = str(metric_parts[4]) if len(metric_parts) == 5 else None + return Metric( + key=metric_name, + value=val, + timestamp=ts, + step=step, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + run_id=run_id, + ) def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None): """ @@ -831,7 +864,7 @@ def get_metric_history(self, run_id, metric_key, max_results=None, page_token=No return PagedList([], None) return PagedList( [ - FileStore._get_metric_from_line(metric_key, line, run_info.experiment_id) + FileStore._get_metric_from_line(run_id, metric_key, line, run_info.experiment_id) for line in read_file_lines(parent_path, metric_key) ], None, @@ -945,17 +978,45 @@ def _search_runs( runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results) return runs, next_page_token - def log_metric(self, run_id, metric): + def log_metric(self, run_id: str, metric: Metric): 
_validate_run_id(run_id) _validate_metric(metric.key, metric.value, metric.timestamp, metric.step) run_info = self._get_run_info(run_id) check_run_is_active(run_info) self._log_run_metric(run_info, metric) + if metric.model_id is not None: + self._log_model_metric( + experiment_id=run_info.experiment_id, + model_id=metric.model_id, + run_id=run_id, + metric=metric, + ) def _log_run_metric(self, run_info, metric): metric_path = self._get_metric_path(run_info.experiment_id, run_info.run_id, metric.key) make_containing_dirs(metric_path) - append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") + if metric.dataset_name is not None and metric.dataset_digest is not None: + append_to( + metric_path, + f"{metric.timestamp} {metric.value} {metric.step} {metric.dataset_name} " + f"{metric.dataset_digest}\n", + ) + else: + append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step}\n") + + def _log_model_metric(self, experiment_id: str, model_id: str, run_id: str, metric: Metric): + metric_path = self._get_model_metric_path( + experiment_id=experiment_id, model_id=model_id, metric_key=metric.key + ) + make_containing_dirs(metric_path) + if metric.dataset_name is not None and metric.dataset_digest is not None: + append_to( + metric_path, + f"{metric.timestamp} {metric.value} {metric.step} {run_id} {metric.dataset_name} " + f"{metric.dataset_digest}\n", + ) + else: + append_to(metric_path, f"{metric.timestamp} {metric.value} {metric.step} {run_id}\n") def _writeable_value(self, tag_value): if tag_value is None: @@ -1077,6 +1138,13 @@ def log_batch(self, run_id, metrics, params, tags): self._log_run_param(run_info, param) for metric in metrics: self._log_run_metric(run_info, metric) + if metric.model_id is not None: + self._log_model_metric( + experiment_id=run_info.experiment_id, + model_id=metric.model_id, + run_id=run_id, + metric=metric, + ) for tag in tags: # NB: If the tag run name value is set, update the run info to assure # 
synchronization. @@ -1112,14 +1180,21 @@ def record_logged_model(self, run_id, mlflow_model): except Exception as e: raise MlflowException(e, INTERNAL_ERROR) - def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None): + def log_inputs( + self, + run_id: str, + datasets: Optional[List[DatasetInput]] = None, + models: Optional[List[ModelInput]] = None, + ): """ - Log inputs, such as datasets, to the specified run. + Log inputs, such as datasets and models, to the specified run. Args: run_id: String id for the run datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log as inputs to the run. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log + as inputs to the run. Returns: None. @@ -1128,13 +1203,13 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) run_info = self._get_run_info(run_id) check_run_is_active(run_info) - if datasets is None: + if datasets is None and models is None: return experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) run_dir = self._get_run_dir(run_info.experiment_id, run_id) - for dataset_input in datasets: + for dataset_input in datasets or []: dataset = dataset_input.dataset dataset_id = FileStore._get_dataset_id( dataset_name=dataset.name, dataset_digest=dataset.digest @@ -1144,7 +1219,7 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) os.makedirs(dataset_dir, exist_ok=True) write_yaml(dataset_dir, FileStore.META_DATA_FILE_NAME, dict(dataset)) - input_id = FileStore._get_input_id(dataset_id=dataset_id, run_id=run_id) + input_id = FileStore._get_dataset_input_id(dataset_id=dataset_id, run_id=run_id) input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id) if not os.path.exists(input_dir): os.makedirs(input_dir, exist_ok=True) @@ -1157,6 +1232,57 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) ) fs_input.write_yaml(input_dir, 
FileStore.META_DATA_FILE_NAME) + for model_input in models or []: + model_id = model_input.model_id + input_id = FileStore._get_model_input_id(model_id=model_id, run_id=run_id) + input_dir = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME, input_id) + if not os.path.exists(input_dir): + os.makedirs(input_dir, exist_ok=True) + fs_input = FileStore._FileStoreInput( + source_type=InputVertexType.MODEL, + source_id=model_id, + destination_type=InputVertexType.RUN, + destination_id=run_id, + tags={}, + ) + fs_input.write_yaml(input_dir, FileStore.META_DATA_FILE_NAME) + + def log_outputs(self, run_id, models: Optional[List[ModelOutput]] = None): + """ + Log outputs, such as models, to the specified run. + + Args: + run_id: String id for the run + models: List of :py:class:`mlflow.entities.ModelOutput` instances to log + as outputs of the run. + + Returns: + None. + """ + _validate_run_id(run_id) + run_info = self._get_run_info(run_id) + check_run_is_active(run_info) + + if models is None: + return + + run_dir = self._get_run_dir(run_info.experiment_id, run_id) + + for model_output in models: + model_id = model_output.model_id + output_dir = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME, model_id) + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + fs_output = FileStore._FileStoreOutput( + source_type=OutputVertexType.RUN_OUTPUT, + source_id=model_id, + destination_type=OutputVertexType.MODEL_OUTPUT, + destination_id=run_id, + tags={}, + step=model_output.step, + ) + fs_output.write_yaml(output_dir, FileStore.META_DATA_FILE_NAME) + @staticmethod def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str: md5 = insecure_hash.md5(dataset_name.encode("utf-8")) @@ -1164,11 +1290,17 @@ def _get_dataset_id(dataset_name: str, dataset_digest: str) -> str: return md5.hexdigest() @staticmethod - def _get_input_id(dataset_id: str, run_id: str) -> str: + def _get_dataset_input_id(dataset_id: str, run_id: str) -> str: md5 = 
insecure_hash.md5(dataset_id.encode("utf-8")) md5.update(run_id.encode("utf-8")) return md5.hexdigest() + @staticmethod + def _get_model_input_id(model_id: str, run_id: str) -> str: + md5 = insecure_hash.md5(model_id.encode("utf-8")) + md5.update(run_id.encode("utf-8")) + return md5.hexdigest() + class _FileStoreInput(NamedTuple): source_type: int source_id: str @@ -1197,13 +1329,54 @@ def from_yaml(cls, root, file_name): tags=dict_from_yaml["tags"], ) + class _FileStoreOutput(NamedTuple): + source_type: int + source_id: str + destination_type: int + destination_id: str + tags: Dict[str, str] + step: int + + def write_yaml(self, root: str, file_name: str): + dict_for_yaml = { + "source_type": OutputVertexType.Name(self.source_type), + "source_id": self.source_id, + "destination_type": OutputVertexType.Name(self.destination_type), + "destination_id": self.source_id, + "tags": self.tags, + "step": self.step, + } + write_yaml(root, file_name, dict_for_yaml) + + @classmethod + def from_yaml(cls, root, file_name): + dict_from_yaml = FileStore._read_yaml(root, file_name) + return cls( + source_type=OutputVertexType.Value(dict_from_yaml["source_type"]), + source_id=dict_from_yaml["source_id"], + destination_type=OutputVertexType.Value(dict_from_yaml["destination_type"]), + destination_id=dict_from_yaml["destination_id"], + tags=dict_from_yaml["tags"], + step=dict_from_yaml["step"], + ) + def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id) inputs_parent_path = os.path.join(run_dir, FileStore.INPUTS_FOLDER_NAME) + if not os.path.exists(inputs_parent_path): + return RunInputs(dataset_inputs=[], model_inputs=[]) + experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) - datasets_parent_path = os.path.join(experiment_dir, FileStore.DATASETS_FOLDER_NAME) - if not os.path.exists(inputs_parent_path) or not os.path.exists(datasets_parent_path): - return 
RunInputs(dataset_inputs=[]) + dataset_inputs = self._get_dataset_inputs(run_info, inputs_parent_path, experiment_dir) + model_inputs = self._get_model_inputs(inputs_parent_path, experiment_dir) + return RunInputs(dataset_inputs=dataset_inputs, model_inputs=model_inputs) + + def _get_dataset_inputs( + self, run_info: RunInfo, inputs_parent_path: str, experiment_dir_path: str + ) -> List[DatasetInput]: + datasets_parent_path = os.path.join(experiment_dir_path, FileStore.DATASETS_FOLDER_NAME) + if not os.path.exists(datasets_parent_path): + return [] dataset_dirs = os.listdir(datasets_parent_path) dataset_inputs = [] @@ -1213,9 +1386,6 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: input_dir_full_path, FileStore.META_DATA_FILE_NAME ) if fs_input.source_type != InputVertexType.DATASET: - logging.warning( - f"Encountered invalid run input source type '{fs_input.source_type}'. Skipping." - ) continue matching_dataset_dirs = [d for d in dataset_dirs if d == fs_input.source_id] @@ -1238,7 +1408,51 @@ def _get_all_inputs(self, run_info: RunInfo) -> RunInputs: ) dataset_inputs.append(dataset_input) - return RunInputs(dataset_inputs=dataset_inputs) + return dataset_inputs + + def _get_model_inputs( + self, inputs_parent_path: str, experiment_dir_path: str + ) -> List[ModelInput]: + model_inputs = [] + for input_dir in os.listdir(inputs_parent_path): + input_dir_full_path = os.path.join(inputs_parent_path, input_dir) + fs_input = FileStore._FileStoreInput.from_yaml( + input_dir_full_path, FileStore.META_DATA_FILE_NAME + ) + if fs_input.source_type != InputVertexType.MODEL: + continue + + model_input = ModelInput(model_id=fs_input.source_id) + model_inputs.append(model_input) + + return model_inputs + + def _get_all_outputs(self, run_info: RunInfo) -> RunOutputs: + run_dir = self._get_run_dir(run_info.experiment_id, run_info.run_id) + outputs_parent_path = os.path.join(run_dir, FileStore.OUTPUTS_FOLDER_NAME) + if not os.path.exists(outputs_parent_path): + return 
RunOutputs(model_outputs=[]) + + experiment_dir = self._get_experiment_path(run_info.experiment_id, assert_exists=True) + model_outputs = self._get_model_outputs(outputs_parent_path, experiment_dir) + return RunOutputs(model_outputs=model_outputs) + + def _get_model_outputs( + self, outputs_parent_path: str, experiment_dir: str + ) -> List[ModelOutput]: + model_outputs = [] + for output_dir in os.listdir(outputs_parent_path): + output_dir_full_path = os.path.join(outputs_parent_path, output_dir) + fs_output = FileStore._FileStoreOutput.from_yaml( + output_dir_full_path, FileStore.META_DATA_FILE_NAME + ) + if fs_output.destination_type != OutputVertexType.MODEL_OUTPUT: + continue + + model_output = ModelOutput(model_id=fs_output.destination_id, step=fs_output.step) + model_outputs.append(model_output) + + return model_outputs def _search_datasets(self, experiment_ids) -> List[_DatasetSummary]: """ @@ -1686,3 +1900,294 @@ def _list_trace_infos(self, experiment_id): exc_info=_logger.isEnabledFor(logging.DEBUG), ) return trace_infos + + def create_logged_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[List[ModelTag]] = None, + params: Optional[List[ModelParam]] = None, + model_type: Optional[str] = None, + ) -> LoggedModel: + """ + Create a new model. + + Args: + experiment_id: ID of the Experiment where the model is being created. + name: Name of the model. + run_id: Run ID where the model is being created from. + tags: Key-value tags for the model. + params: Key-value params for the model. + + Returns: + The model version. + """ + experiment_id = FileStore.DEFAULT_EXPERIMENT_ID if experiment_id is None else experiment_id + experiment = self.get_experiment(experiment_id) + if experiment is None: + raise MlflowException( + "Could not create model under experiment with ID %s - no such experiment " + "exists." 
% experiment_id, + databricks_pb2.RESOURCE_DOES_NOT_EXIST, + ) + if experiment.lifecycle_stage != LifecycleStage.ACTIVE: + raise MlflowException( + f"Could not create model under non-active experiment with ID {experiment_id}.", + databricks_pb2.INVALID_STATE, + ) + for param in params or []: + _validate_param(param.key, param.value) + + model_id = str(uuid.uuid4()) + artifact_location = self._get_model_artifact_dir(experiment_id, model_id) + creation_timestamp = int(time.time() * 1000) + model = LoggedModel( + experiment_id=experiment_id, + model_id=model_id, + name=name, + artifact_location=artifact_location, + creation_timestamp=creation_timestamp, + last_updated_timestamp=creation_timestamp, + run_id=run_id, + status=ModelStatus.PENDING, + tags=tags, + params=params, + model_type=model_type, + ) + + # Persist model metadata and create directories for logging metrics, tags + model_dir = self._get_model_dir(experiment_id, model_id) + mkdir(model_dir) + model_info_dict: Dict[str, Any] = self._make_persisted_model_dict(model) + write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict) + mkdir(model_dir, FileStore.METRICS_FOLDER_NAME) + for tag in tags or []: + self.set_logged_model_tag(model_id=model_id, tag=tag) + + return self.get_logged_model(model_id=model_id) + + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: + """ + Finalize a model by updating its status. + + Args: + model_id: ID of the model to finalize. + status: Final status to set on the model. + + Returns: + The updated model. + """ + if status != ModelStatus.READY: + raise MlflowException( + f"Invalid model status: {status}. 
Expected statuses: [{ModelStatus.READY}]", + databricks_pb2.INVALID_PARAMETER_VALUE, + ) + model_dict = self._get_model_dict(model_id) + model = LoggedModel.from_dictionary(model_dict) + model.status = status + model.last_updated_timestamp = int(time.time() * 1000) + model_dir = self._get_model_dir(model.experiment_id, model.model_id) + model_info_dict = self._make_persisted_model_dict(model) + write_yaml(model_dir, FileStore.META_DATA_FILE_NAME, model_info_dict, overwrite=True) + return self.get_logged_model(model_id) + + def set_logged_model_tag(self, model_id: str, tag: ModelTag): + _validate_tag_name(tag.key) + model = self.get_logged_model(model_id) + tag_path = os.path.join( + self._get_model_dir(model.experiment_id, model.model_id), + FileStore.TAGS_FOLDER_NAME, + tag.key, + ) + make_containing_dirs(tag_path) + # Don't add trailing newline + write_to(tag_path, self._writeable_value(tag.value)) + + def get_logged_model(self, model_id: str) -> LoggedModel: + return LoggedModel.from_dictionary(self._get_model_dict(model_id)) + + def _get_model_artifact_dir(self, experiment_id: str, model_id: str) -> str: + return append_to_uri_path( + self.get_experiment(experiment_id).artifact_location, + FileStore.MODELS_FOLDER_NAME, + model_id, + FileStore.ARTIFACTS_FOLDER_NAME, + ) + + def _make_persisted_model_dict(self, model: LoggedModel) -> Dict[str, Any]: + model_dict = model.to_dictionary() + model_dict.pop("tags", None) + model_dict.pop("metrics", None) + return model_dict + + def _get_model_dict(self, model_id: str) -> Dict[str, Any]: + exp_id, model_dir = self._find_model_root(model_id) + if model_dir is None: + raise MlflowException( + f"Model '{model_id}' not found", databricks_pb2.RESOURCE_DOES_NOT_EXIST + ) + model_dict: Dict[str, Any] = self._get_model_info_from_dir(model_dir) + if model_dict["experiment_id"] != exp_id: + raise MlflowException( + f"Model '{model_id}' metadata is in invalid state.", databricks_pb2.INVALID_STATE + ) + return model_dict + + def 
_get_model_dir(self, experiment_id: str, model_id: str) -> str: + if not self._has_experiment(experiment_id): + return None + return os.path.join( + self._get_experiment_path(experiment_id, assert_exists=True), + FileStore.MODELS_FOLDER_NAME, + model_id, + ) + + def _find_model_root(self, model_id): + self._check_root_dir() + all_experiments = self._get_active_experiments(False) + self._get_deleted_experiments(False) + for experiment_dir in all_experiments: + models_dir_path = os.path.join( + self.root_directory, experiment_dir, FileStore.MODELS_FOLDER_NAME + ) + models = find(models_dir_path, model_id, full_path=True) + if len(models) == 0: + continue + return os.path.basename(os.path.dirname(os.path.abspath(models_dir_path))), models[0] + return None, None + + def _get_model_from_dir(self, model_dir: str) -> LoggedModel: + return LoggedModel.from_dictionary(self._get_model_info_from_dir(model_dir)) + + def _get_model_info_from_dir(self, model_dir: str) -> Dict[str, Any]: + model_dict = FileStore._read_yaml(model_dir, FileStore.META_DATA_FILE_NAME) + model_dict["tags"] = self._get_all_model_tags(model_dir) + model_dict["metrics"] = self._get_all_model_metrics( + model_id=model_dict["model_id"], model_dir=model_dir + ) + return model_dict + + def _get_all_model_tags(self, model_dir: str) -> List[ModelTag]: + parent_path, tag_files = self._get_resource_files(model_dir, FileStore.TAGS_FOLDER_NAME) + tags = [] + for tag_file in tag_files: + tags.append(self._get_tag_from_file(parent_path, tag_file)) + return tags + + def _get_all_model_metrics(self, model_id: str, model_dir: str) -> List[Metric]: + parent_path, metric_files = self._get_resource_files( + model_dir, FileStore.METRICS_FOLDER_NAME + ) + metrics = [] + for metric_file in metric_files: + metrics.extend( + FileStore._get_model_metrics_from_file( + model_id=model_id, parent_path=parent_path, metric_name=metric_file + ) + ) + return metrics + + @staticmethod + def _get_model_metrics_from_file( + model_id: str, 
parent_path: str, metric_name: str + ) -> List[Metric]: + _validate_metric_name(metric_name) + metric_objs = [ + FileStore._get_model_metric_from_line(model_id, metric_name, line) + for line in read_file_lines(parent_path, metric_name) + ] + if len(metric_objs) == 0: + raise ValueError(f"Metric '{metric_name}' is malformed. No data found.") + + # Group metrics by (dataset_name, dataset_digest) + grouped_metrics = defaultdict(list) + for metric in metric_objs: + key = (metric.dataset_name, metric.dataset_digest) + grouped_metrics[key].append(metric) + + # Compute the max for each group + return [ + max(group, key=lambda m: (m.step, m.timestamp, m.value)) + for group in grouped_metrics.values() + ] + + @staticmethod + def _get_model_metric_from_line(model_id: str, metric_name: str, metric_line: str) -> Metric: + metric_parts = metric_line.strip().split(" ") + if len(metric_parts) not in [4, 6]: + raise MlflowException( + f"Metric '{metric_name}' is malformed; persisted metric data contained " + f"{len(metric_parts)} fields. 
Expected 4 or 6 fields.", + databricks_pb2.INTERNAL_ERROR, + ) + ts = int(metric_parts[0]) + val = float(metric_parts[1]) + step = int(metric_parts[2]) + run_id = str(metric_parts[3]) + dataset_name = str(metric_parts[4]) if len(metric_parts) == 6 else None + dataset_digest = str(metric_parts[5]) if len(metric_parts) == 6 else None + # TODO: Read run ID from the metric file and pass it to the Metric constructor + return Metric( + key=metric_name, + value=val, + timestamp=ts, + step=step, + model_id=model_id, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + run_id=run_id, + ) + + def search_logged_models( + self, + experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ) -> List[LoggedModel]: + all_models = [] + for experiment_id in experiment_ids: + models = self._list_models(experiment_id) + all_models.extend(models) + filtered = SearchUtils.filter_logged_models(models, filter_string) + return SearchUtils.sort_logged_models(filtered, order_by)[:max_results] + + def _list_models(self, experiment_id: str) -> List[LoggedModel]: + self._check_root_dir() + if not self._has_experiment(experiment_id): + return [] + experiment_dir = self._get_experiment_path(experiment_id, assert_exists=True) + model_dirs = list_all( + os.path.join(experiment_dir, FileStore.MODELS_FOLDER_NAME), + filter_func=lambda x: all( + os.path.basename(os.path.normpath(x)) != reservedFolderName + for reservedFolderName in FileStore.RESERVED_EXPERIMENT_FOLDERS + ) + and os.path.isdir(x), + full_path=True, + ) + models = [] + for m_dir in model_dirs: + try: + # trap and warn known issues, will raise unexpected exceptions to caller + model = self._get_model_from_dir(m_dir) + if model.experiment_id != experiment_id: + logging.warning( + "Wrong experiment ID (%s) recorded for model '%s'. " + "It should be %s. 
Model will be ignored.", + str(model.experiment_id), + str(model.model_id), + str(experiment_id), + exc_info=True, + ) + continue + models.append(model) + except MissingConfigException as exc: + # trap malformed model exception and log + # this is at debug level because if the same store is used for + # artifact storage, it's common the folder is not a run folder + m_id = os.path.basename(m_dir) + logging.debug( + "Malformed model '%s'. Detailed error %s", m_id, str(exc), exc_info=True + ) + return models diff --git a/mlflow/tracing/constant.py b/mlflow/tracing/constant.py index 6fd5c0c025178..68eaa47c20aa9 100644 --- a/mlflow/tracing/constant.py +++ b/mlflow/tracing/constant.py @@ -3,6 +3,7 @@ class TraceMetadataKey: INPUTS = "mlflow.traceInputs" OUTPUTS = "mlflow.traceOutputs" SOURCE_RUN = "mlflow.sourceRun" + MODEL_ID = "mlflow.modelId" class TraceTagKey: @@ -19,6 +20,7 @@ class SpanAttributeKey: OUTPUTS = "mlflow.spanOutputs" SPAN_TYPE = "mlflow.spanType" FUNCTION_NAME = "mlflow.spanFunctionName" + MODEL_ID = "mlflow.modelId" # All storage backends are guaranteed to support key values up to 250 characters diff --git a/mlflow/tracing/display/display_handler.py b/mlflow/tracing/display/display_handler.py index f8bff89122048..5da9553d2b618 100644 --- a/mlflow/tracing/display/display_handler.py +++ b/mlflow/tracing/display/display_handler.py @@ -90,6 +90,10 @@ def get_mimebundle(self, traces: List[Trace]): } def display_traces(self, traces: List[Trace]): + # Temporarily disable rendering of traces in Databricks notebooks, + # since it doesnt' work with file-based storage + return + # This only works in Databricks notebooks if not is_in_databricks_runtime(): return diff --git a/mlflow/tracing/fluent.py b/mlflow/tracing/fluent.py index c1c7d66306de8..dc0932acd05d1 100644 --- a/mlflow/tracing/fluent.py +++ b/mlflow/tracing/fluent.py @@ -58,6 +58,7 @@ def trace( name: Optional[str] = None, span_type: str = SpanType.UNKNOWN, attributes: Optional[Dict[str, Any]] = None, 
+ model_id: Optional[str] = None, ) -> Callable: """ A decorator that creates a new span for the decorated function. @@ -135,7 +136,9 @@ class _WrappingContext: def _wrapping_logic(fn, args, kwargs): span_name = name or fn.__name__ - with start_span(name=span_name, span_type=span_type, attributes=attributes) as span: + with start_span( + name=span_name, span_type=span_type, attributes=attributes, model_id=model_id + ) as span: span.set_attribute(SpanAttributeKey.FUNCTION_NAME, fn.__name__) try: span.set_inputs(capture_function_input_args(fn, args, kwargs)) @@ -184,6 +187,7 @@ def start_span( name: str = "span", span_type: Optional[str] = SpanType.UNKNOWN, attributes: Optional[Dict[str, Any]] = None, + model_id: Optional[str] = None, ) -> Generator[LiveSpan, None, None]: """ Context manager to create a new span and start it as the current span in the context. @@ -253,9 +257,11 @@ def start_span( # Create a new MLflow span and register it to the in-memory trace manager request_id = get_otel_attribute(otel_span, SpanAttributeKey.REQUEST_ID) mlflow_span = create_mlflow_span(otel_span, request_id, span_type) - mlflow_span.set_attributes(attributes or {}) + attributes = dict(attributes) if attributes is not None else {} + if model_id is not None: + attributes[SpanAttributeKey.MODEL_ID] = model_id + mlflow_span.set_attributes(attributes) InMemoryTraceManager.get_instance().register_span(mlflow_span) - except Exception as e: _logger.warning( f"Failed to start span: {e}. For full traceback, set logging level to debug.", @@ -332,6 +338,7 @@ def search_traces( max_results: Optional[int] = None, order_by: Optional[List[str]] = None, extract_fields: Optional[List[str]] = None, + model_id: Optional[str] = None, ) -> "pandas.DataFrame": """ Return traces that match the given list of search expressions within the experiments. 
@@ -430,6 +437,7 @@ def pagination_wrapper_func(number_to_get, next_page_token): filter_string=filter_string, order_by=order_by, page_token=next_page_token, + model_id=model_id, ) results = get_results_from_paginated_fn( diff --git a/mlflow/tracing/processor/mlflow.py b/mlflow/tracing/processor/mlflow.py index bed4c48f275b4..21decbda90ffa 100644 --- a/mlflow/tracing/processor/mlflow.py +++ b/mlflow/tracing/processor/mlflow.py @@ -145,17 +145,21 @@ def on_end(self, span: OTelReadableSpan) -> None: return request_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID) + # TODO: We should remove the model ID from the span attributes + model_id = get_otel_attribute(span, SpanAttributeKey.MODEL_ID) with self._trace_manager.get_trace(request_id) as trace: if trace is None: _logger.debug(f"Trace data with request ID {request_id} not found.") return - self._update_trace_info(trace, span) + self._update_trace_info(trace, span, model_id) deduplicate_span_names_in_place(list(trace.span_dict.values())) super().on_end(span) - def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan): + def _update_trace_info( + self, trace: _Trace, root_span: OTelReadableSpan, model_id: Optional[str] + ): """Update the trace info with the final values from the root span.""" # The trace/span start time needs adjustment to exclude the latency of # the backend API call. 
We already adjusted the span start time in the @@ -173,6 +177,8 @@ def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan): ), } ) + if model_id is not None: + trace.info.request_metadata[SpanAttributeKey.MODEL_ID] = model_id def _truncate_metadata(self, value: Optional[str]) -> str: """Get truncated value of the attribute if it exceeds the maximum length.""" diff --git a/mlflow/tracking/_model_registry/client.py b/mlflow/tracking/_model_registry/client.py index 5cbe391b7debe..d3773bc029077 100644 --- a/mlflow/tracking/_model_registry/client.py +++ b/mlflow/tracking/_model_registry/client.py @@ -4,6 +4,7 @@ exposed in the :py:mod:`mlflow.tracking` module. """ import logging +from typing import Optional from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag from mlflow.exceptions import MlflowException @@ -188,6 +189,7 @@ def create_model_version( description=None, await_creation_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, local_model_path=None, + model_id: Optional[str] = None, ): """Create a new model version from given source. @@ -202,6 +204,8 @@ def create_model_version( await_creation_for: Number of seconds to wait for the model version to finish being created and is in ``READY`` status. By default, the function waits for five minutes. Specify 0 or None to skip waiting. + model_id: The ID of the model (from an Experiment) that is being promoted to a + registered model version, if applicable. 
Returns: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by @@ -220,6 +224,7 @@ def create_model_version( run_link, description, local_model_path=local_model_path, + model_id=model_id, ) else: # Fall back to calling create_model_version without diff --git a/mlflow/tracking/_model_registry/fluent.py b/mlflow/tracking/_model_registry/fluent.py index 50892ab28063c..8d5b66c9d2de2 100644 --- a/mlflow/tracking/_model_registry/fluent.py +++ b/mlflow/tracking/_model_registry/fluent.py @@ -4,6 +4,7 @@ from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import ALREADY_EXISTS, RESOURCE_ALREADY_EXISTS, ErrorCode from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository +from mlflow.store.artifact.utils.models import _parse_model_uri from mlflow.store.model_registry import ( SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT, SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, @@ -75,7 +76,10 @@ def register_model( Version: 1 """ return _register_model( - model_uri=model_uri, name=name, await_registration_for=await_registration_for, tags=tags + model_uri=model_uri, + name=name, + await_registration_for=await_registration_for, + tags=tags, ) @@ -109,6 +113,7 @@ def _register_model( source = RunsArtifactRepository.get_underlying_uri(model_uri) (run_id, _) = RunsArtifactRepository.parse_runs_uri(model_uri) + parsed_model_uri = _parse_model_uri(model_uri) create_version_response = client._create_model_version( name=name, source=source, @@ -116,6 +121,7 @@ def _register_model( tags=tags, await_creation_for=await_registration_for, local_model_path=local_model_path, + model_id=parsed_model_uri.model_id, ) eprint( f"Created version '{create_version_response.version}' of model " diff --git a/mlflow/tracking/_tracking_service/client.py b/mlflow/tracking/_tracking_service/client.py index 70f5dd38681c0..f7a0c3f95e494 100644 --- a/mlflow/tracking/_tracking_service/client.py +++ b/mlflow/tracking/_tracking_service/client.py @@ -13,7 
+13,13 @@ from mlflow.entities import ( ExperimentTag, + LoggedModel, Metric, + ModelInput, + ModelOutput, + ModelParam, + ModelStatus, + ModelTag, Param, RunStatus, RunTag, @@ -32,6 +38,7 @@ MlflowTraceDataNotFound, ) from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE, ErrorCode +from mlflow.store.artifact.artifact_repo import ArtifactRepository from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.store.entities.paged_list import PagedList from mlflow.store.tracking import ( @@ -296,7 +303,18 @@ def _search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ): + if model_id is not None: + if filter_string: + raise MlflowException( + message=( + "Cannot specify both `model_id` and `experiment_ids` or `filter_string`" + " in the search_traces call." + ), + error_code=INVALID_PARAMETER_VALUE, + ) + filter_string = f"request_metadata.`mlflow.modelId` = '{model_id}'" return self.store.search_traces( experiment_ids=experiment_ids, filter_string=filter_string, @@ -312,6 +330,7 @@ def search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ) -> PagedList[Trace]: def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]: """ @@ -343,6 +362,7 @@ def download_trace_data(trace_info: TraceInfo) -> Optional[Trace]: max_results=next_max_results, order_by=order_by, page_token=next_token, + model_id=model_id, ) traces.extend(t for t in executor.map(download_trace_data, trace_infos) if t) @@ -532,7 +552,16 @@ def rename_experiment(self, experiment_id, new_name): self.store.rename_experiment(experiment_id, new_name) def log_metric( - self, run_id, key, value, timestamp=None, step=None, synchronous=True + self, + run_id, + key, + value, + timestamp=None, + step=None, + 
synchronous=True, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, + model_id: Optional[str] = None, ) -> Optional[RunOperations]: """Log a metric against the run ID. @@ -559,7 +588,15 @@ def log_metric( timestamp = timestamp if timestamp is not None else get_current_time_millis() step = step if step is not None else 0 metric_value = convert_metric_value_to_float_if_possible(value) - metric = Metric(key, metric_value, timestamp, step) + metric = Metric( + key, + metric_value, + timestamp, + step, + model_id=model_id, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) if synchronous: self.store.log_metric(run_id, metric) else: @@ -698,10 +735,14 @@ def log_batch( metrics = [ Metric( - metric.key, - convert_metric_value_to_float_if_possible(metric.value), - metric.timestamp, - metric.step, + key=metric.key, + value=convert_metric_value_to_float_if_possible(metric.value), + timestamp=metric.timestamp, + step=metric.step, + dataset_name=metric.dataset_name, + dataset_digest=metric.dataset_digest, + model_id=metric.model_id, + run_id=metric.run_id, ) for metric in metrics ] @@ -752,12 +793,18 @@ def log_batch( # Merge all the run operations into a single run operations object return get_combined_run_operations(run_operations_list) - def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None): + def log_inputs( + self, + run_id: str, + datasets: Optional[List[DatasetInput]] = None, + models: Optional[List[ModelInput]] = None, + ): """Log one or more dataset inputs to a run. Args: run_id: String ID of the run datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log. Raises: MlflowException: If any errors occur. 
@@ -765,10 +812,10 @@ def log_inputs(self, run_id: str, datasets: Optional[List[DatasetInput]] = None) Returns: None """ - if datasets is None or len(datasets) == 0: - return + self.store.log_inputs(run_id=run_id, datasets=datasets, models=models) - self.store.log_inputs(run_id=run_id, datasets=datasets) + def log_outputs(self, run_id: str, models: List[ModelOutput]): + self.store.log_outputs(run_id=run_id, models=models) def _record_logged_model(self, run_id, mlflow_model): from mlflow.models import Model @@ -976,3 +1023,64 @@ def search_runs( order_by=order_by, page_token=page_token, ) + + def create_logged_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + model_type: Optional[str] = None, + ) -> LoggedModel: + return self.store.create_logged_model( + experiment_id=experiment_id, + name=name, + run_id=run_id, + tags=[ModelTag(str(key), str(value)) for key, value in tags.items()] + if tags is not None + else tags, + params=[ModelParam(str(key), str(value)) for key, value in params.items()] + if params is not None + else params, + model_type=model_type, + ) + + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: + return self.store.finalize_logged_model(model_id, status) + + def get_logged_model(self, model_id: str) -> LoggedModel: + return self.store.get_logged_model(model_id) + + def set_logged_model_tag(self, model_id: str, key: str, value: str): + return self.store.set_logged_model_tag(model_id, ModelTag(key, value)) + + def log_model_artifacts(self, model_id: str, local_dir: str) -> None: + self._get_artifact_repo_for_logged_model(model_id).log_artifacts(local_dir) + + def search_logged_models( + self, + experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ): + return self.store.search_logged_models(experiment_ids, 
filter_string, max_results, order_by) + + def _get_artifact_repo_for_logged_model(self, model_id: str) -> ArtifactRepository: + # Attempt to fetch the artifact repo from a local cache + cached_repo = utils._artifact_repos_cache.get(model_id) + if cached_repo is not None: + return cached_repo + else: + model = self.get_logged_model(model_id) + artifact_uri = add_databricks_profile_info_to_artifact_uri( + model.artifact_location, self.tracking_uri + ) + artifact_repo = get_artifact_repository(artifact_uri) + # Cache the artifact repo to avoid a future network call, removing the oldest + # entry in the cache if there are too many elements + if len(utils._artifact_repos_cache) > 1024: + utils._artifact_repos_cache.popitem(last=False) + utils._artifact_repos_cache[model_id] = artifact_repo + return artifact_repo diff --git a/mlflow/tracking/client.py b/mlflow/tracking/client.py index f85e93755e296..60394810cb44f 100644 --- a/mlflow/tracking/client.py +++ b/mlflow/tracking/client.py @@ -23,7 +23,11 @@ DatasetInput, Experiment, FileInfo, + LoggedModel, Metric, + ModelInput, + ModelOutput, + ModelStatus, Param, Run, RunTag, @@ -485,6 +489,7 @@ def search_traces( max_results: int = SEARCH_TRACES_DEFAULT_MAX_RESULTS, order_by: Optional[List[str]] = None, page_token: Optional[str] = None, + model_id: Optional[str] = None, ) -> PagedList[Trace]: """ Return traces that match the given list of search expressions within the experiments. @@ -511,6 +516,7 @@ def search_traces( max_results=max_results, order_by=order_by, page_token=page_token, + model_id=model_id, ) get_display_handler().display_traces(traces) @@ -1440,6 +1446,9 @@ def log_metric( timestamp: Optional[int] = None, step: Optional[int] = None, synchronous: Optional[bool] = None, + dataset_name: Optional[str] = None, + dataset_digest: Optional[str] = None, + model_id: Optional[str] = None, ) -> Optional[RunOperations]: """ Log a metric against the run ID. 
@@ -1515,7 +1524,15 @@ def print_run_info(r): synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() ) return self._tracking_client.log_metric( - run_id, key, value, timestamp, step, synchronous=synchronous + run_id, + key, + value, + timestamp, + step, + synchronous=synchronous, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + model_id=model_id, ) def log_param( @@ -1860,6 +1877,7 @@ def log_inputs( self, run_id: str, datasets: Optional[Sequence[DatasetInput]] = None, + models: Optional[Sequence[ModelInput]] = None, ) -> None: """ Log one or more dataset inputs to a run. @@ -1867,11 +1885,15 @@ def log_inputs( Args: run_id: String ID of the run. datasets: List of :py:class:`mlflow.entities.DatasetInput` instances to log. + models: List of :py:class:`mlflow.entities.ModelInput` instances to log. Raises: mlflow.MlflowException: If any errors occur. """ - self._tracking_client.log_inputs(run_id, datasets) + self._tracking_client.log_inputs(run_id, datasets, models) + + def log_outputs(self, run_id: str, models: Sequence[ModelOutput]): + self._tracking_client.log_outputs(run_id, models) def log_artifact(self, run_id, local_path, artifact_path=None) -> None: """Write a local file or directory to the remote ``artifact_uri``. 
@@ -3585,6 +3607,7 @@ def _create_model_version( description: Optional[str] = None, await_creation_for: int = DEFAULT_AWAIT_MAX_SLEEP_SECONDS, local_model_path: Optional[str] = None, + model_id: Optional[str] = None, ) -> ModelVersion: tracking_uri = self._tracking_client.tracking_uri if ( @@ -3627,6 +3650,7 @@ def _create_model_version( description=description, await_creation_for=await_creation_for, local_model_path=local_model_path, + model_id=model_id, ) def create_model_version( @@ -4735,3 +4759,39 @@ def print_model_version_info(mv): """ _validate_model_name(name) return self._get_registry_client().get_model_version_by_alias(name, alias) + + def create_logged_model( + self, + experiment_id: str, + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + model_type: Optional[str] = None, + ) -> LoggedModel: + return self._tracking_client.create_logged_model( + experiment_id, name, run_id, tags, params, model_type + ) + + def finalize_logged_model(self, model_id: str, status: ModelStatus) -> LoggedModel: + return self._tracking_client.finalize_logged_model(model_id, status) + + def get_logged_model(self, model_id: str) -> LoggedModel: + return self._tracking_client.get_logged_model(model_id) + + def set_logged_model_tag(self, model_id: str, key: str, value: str): + return self._tracking_client.set_logged_model_tag(model_id, key, value) + + def log_model_artifacts(self, model_id: str, local_dir: str) -> None: + return self._tracking_client.log_model_artifacts(model_id, local_dir) + + def search_logged_models( + self, + experiment_ids: List[str], + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + ): + return self._tracking_client.search_logged_models( + experiment_ids, filter_string, max_results, order_by + ) diff --git a/mlflow/tracking/fluent.py b/mlflow/tracking/fluent.py index 257627f316b62..4e4d3f5fc10a7 100644 --- 
a/mlflow/tracking/fluent.py +++ b/mlflow/tracking/fluent.py @@ -17,7 +17,9 @@ DatasetInput, Experiment, InputTag, + LoggedModel, Metric, + ModelInput, Param, Run, RunStatus, @@ -825,6 +827,8 @@ def log_metric( synchronous: Optional[bool] = None, timestamp: Optional[int] = None, run_id: Optional[str] = None, + model_id: Optional[str] = None, + dataset: Optional[Dataset] = None, ) -> Optional[RunOperations]: """ Log a metric under the current run. If no run is active, this method will create @@ -868,14 +872,77 @@ def log_metric( """ run_id = run_id or _get_or_start_run().info.run_id synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() - return MlflowClient().log_metric( + _log_inputs_for_metrics_if_necessary( run_id, - key, - value, - timestamp or get_current_time_millis(), - step or 0, - synchronous=synchronous, + [ + Metric( + key=key, + value=value, + timestamp=timestamp or get_current_time_millis(), + step=step or 0, + model_id=model_id, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, + ), + ], + datasets=[dataset] if dataset is not None else None, + ) + timestamp = timestamp or get_current_time_millis() + step = step or 0 + model_ids = ( + [model_id] + if model_id is not None + else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None]) ) + for model_id in model_ids: + return MlflowClient().log_metric( + run_id, + key, + value, + timestamp, + step, + synchronous=synchronous, + model_id=model_id, + dataset_name=dataset.name if dataset is not None else None, + dataset_digest=dataset.digest if dataset is not None else None, + ) + + +def _log_inputs_for_metrics_if_necessary( + run_id, metrics: List[Metric], datasets: Optional[List[Dataset]] = None +) -> None: + client = MlflowClient() + run = client.get_run(run_id) + datasets = datasets or [] + for metric in metrics: + if metric.model_id is not None and metric.model_id not in [ + 
inp.model_id for inp in run.inputs.model_inputs + ] + [output.model_id for output in run.outputs.model_outputs]: + client.log_inputs(run_id, models=[ModelInput(model_id=metric.model_id)]) + if (metric.dataset_name, metric.dataset_digest) not in [ + (inp.dataset.name, inp.dataset.digest) for inp in run.inputs.dataset_inputs + ]: + matching_dataset = next( + ( + dataset + for dataset in datasets + if dataset.name == metric.dataset_name + and dataset.digest == metric.dataset_digest + ), + None, + ) + if matching_dataset is not None: + client.log_inputs( + run_id, + datasets=[DatasetInput(matching_dataset._to_mlflow_entity(), tags=[])], + ) + + +def _get_model_ids_for_new_metric_if_exist(run_id: str, metric_step: str) -> List[str]: + client = MlflowClient() + run = client.get_run(run_id) + model_outputs_at_step = [mo for mo in run.outputs.model_outputs if mo.step == metric_step] + return [mo.model_id for mo in model_outputs_at_step] def log_metrics( @@ -884,6 +951,8 @@ def log_metrics( synchronous: Optional[bool] = None, run_id: Optional[str] = None, timestamp: Optional[int] = None, + model_id: Optional[str] = None, + dataset: Optional[Dataset] = None, ) -> Optional[RunOperations]: """ Log multiple metrics for the current run. 
If no run is active, this method will create a new @@ -925,10 +994,37 @@ def log_metrics( """ run_id = run_id or _get_or_start_run().info.run_id timestamp = timestamp or get_current_time_millis() - metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + step = step or 0 + dataset_name = dataset.name if dataset is not None else None + dataset_digest = dataset.digest if dataset is not None else None + model_ids = ( + [model_id] + if model_id is not None + else (_get_model_ids_for_new_metric_if_exist(run_id, step) or [None]) + ) + metrics_arr = [ + Metric( + key, + value, + timestamp, + step or 0, + model_id=model_id, + dataset_name=dataset_name, + dataset_digest=dataset_digest, + ) + for key, value in metrics.items() + for model_id in model_ids + ] + _log_inputs_for_metrics_if_necessary( + run_id, metrics_arr, [dataset] if dataset is not None else None + ) synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() return MlflowClient().log_batch( - run_id=run_id, metrics=metrics_arr, params=[], tags=[], synchronous=synchronous + run_id=run_id, + metrics=metrics_arr, + params=[], + tags=[], + synchronous=synchronous, ) @@ -1815,6 +1911,60 @@ def delete_experiment(experiment_id: str) -> None: MlflowClient().delete_experiment(experiment_id) +def create_logged_model( + name: str, + run_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + model_type: Optional[str] = None, + experiment_id: Optional[str] = None, +) -> LoggedModel: + run = active_run() + if run_id is None and run is not None: + run_id = run.info.run_id + experiment_id = experiment_id if experiment_id is not None else _get_experiment_id() + return MlflowClient().create_logged_model( + experiment_id=experiment_id, + name=name, + run_id=run_id, + tags=tags, + params=params, + model_type=model_type, + ) + + +def get_logged_model(model_id: str) -> LoggedModel: + return 
MlflowClient().get_logged_model(model_id) + + +def search_logged_models( + experiment_ids: Optional[List[str]] = None, + filter_string: Optional[str] = None, + max_results: Optional[int] = None, + order_by: Optional[List[str]] = None, + output_format: str = "pandas", +) -> Union[List[LoggedModel], "pandas.DataFrame"]: + experiment_ids = experiment_ids or [_get_experiment_id()] + models = MlflowClient().search_logged_models( + experiment_ids=experiment_ids, + filter_string=filter_string, + max_results=max_results, + order_by=order_by, + ) + if output_format == "pandas": + import pandas as pd + + return pd.DataFrame([model.to_dictionary() for model in models]) + elif output_format == "list": + return models + else: + raise MlflowException( + "Unsupported output format: %s. Supported string values are 'pandas' or 'list'" + % output_format, + INVALID_PARAMETER_VALUE, + ) + + def delete_run(run_id: str) -> None: """ Deletes a run with the given ID. diff --git a/mlflow/utils/search_utils.py b/mlflow/utils/search_utils.py index aa399f93f0e99..0a983421cfc01 100644 --- a/mlflow/utils/search_utils.py +++ b/mlflow/utils/search_utils.py @@ -5,7 +5,7 @@ import operator import re import shlex -from typing import Any, Dict +from typing import Any, Dict, List, Optional import sqlparse from packaging.version import Version @@ -20,7 +20,7 @@ ) from sqlparse.tokens import Token as TokenType -from mlflow.entities import RunInfo +from mlflow.entities import LoggedModel, RunInfo from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE @@ -635,6 +635,38 @@ def _does_run_match_clause(cls, run, sed): return SearchUtils.get_comparison_func(comparator)(lhs, value) + @classmethod + def _does_model_match_clause(cls, model, sed): + key_type = sed.get("type") + key = sed.get("key") + value = sed.get("value") + comparator = sed.get("comparator").upper() + + 
key = SearchUtils.translate_key_alias(key) + + if cls.is_metric(key_type, comparator): + matching_metrics = [metric for metric in model.metrics if metric.key == key] + lhs = matching_metrics[0].value if matching_metrics else None + value = float(value) + elif cls.is_param(key_type, comparator): + lhs = model.params.get(key, None) + elif cls.is_tag(key_type, comparator): + lhs = model.tags.get(key, None) + elif cls.is_string_attribute(key_type, key, comparator): + lhs = getattr(model.info, key) + elif cls.is_numeric_attribute(key_type, key, comparator): + lhs = getattr(model.info, key) + value = int(value) + else: + raise MlflowException( + f"Invalid model search expression type '{key_type}'", + error_code=INVALID_PARAMETER_VALUE, + ) + if lhs is None: + return False + + return SearchUtils.get_comparison_func(comparator)(lhs, value) + @classmethod def filter(cls, runs, filter_string): """Filters a set of runs based on a search filter string.""" @@ -647,6 +679,20 @@ def run_matches(run): return [run for run in runs if run_matches(run)] + @classmethod + def filter_logged_models(cls, models: List[LoggedModel], filter_string: Optional[str] = None): + """Filters a set of runs based on a search filter string.""" + if not filter_string: + return models + + # TODO: Update parsing function to handle model-specific filter clauses + parsed = cls.parse_search_filter(filter_string) + + def model_matches(model): + return all(cls._does_model_match_clause(model, s) for s in parsed) + + return [model for model in models if model_matches(model)] + @classmethod def _validate_order_by_and_generate_token(cls, order_by): try: @@ -760,6 +806,40 @@ def _get_value_for_sort(cls, run, key_type, key, ascending): return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value) + @classmethod + def _get_model_value_for_sort(cls, model, key_type, key, ascending): + """Returns a tuple suitable to be used as a sort key for models.""" + sort_value = None + key = 
SearchUtils.translate_key_alias(key) + if key_type == cls._METRIC_IDENTIFIER: + matching_metrics = [metric for metric in model.metrics if metric.key == key] + sort_value = float(matching_metrics[0].value) if matching_metrics else None + elif key_type == cls._PARAM_IDENTIFIER: + sort_value = model.params.get(key) + elif key_type == cls._TAG_IDENTIFIER: + sort_value = model.tags.get(key) + elif key_type == cls._ATTRIBUTE_IDENTIFIER: + sort_value = getattr(model, key) + else: + raise MlflowException( + f"Invalid models order_by entity type '{key_type}'", + error_code=INVALID_PARAMETER_VALUE, + ) + + # Return a key such that None values are always at the end. + is_none = sort_value is None + is_nan = isinstance(sort_value, float) and math.isnan(sort_value) + fill_value = (1 if ascending else -1) * math.inf + + if is_none: + sort_value = fill_value + elif is_nan: + sort_value = -fill_value + + is_none_or_nan = is_none or is_nan + + return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value) + @classmethod def sort(cls, runs, order_by_list): """Sorts a set of runs based on their natural ordering and an overriding set of order_bys. @@ -780,6 +860,24 @@ def sort(cls, runs, order_by_list): ) return runs + @classmethod + def sort_logged_models(cls, models, order_by_list): + models = sorted(models, key=lambda model: (-model.creation_timestamp, model.model_id)) + if not order_by_list: + return models + # NB: We rely on the stability of Python's sort function, so that we can apply + # the ordering conditions in reverse order. 
+ for order_by_clause in reversed(order_by_list): + # TODO: Update parsing function to handle model-specific order-by keys + (key_type, key, ascending) = cls.parse_order_by_for_search_runs(order_by_clause) + + models = sorted( + models, + key=lambda model: cls._get_model_value_for_sort(model, key_type, key, ascending), + reverse=not ascending, + ) + return models + @classmethod def parse_start_offset_from_page_token(cls, page_token): # Note: the page_token is expected to be a base64-encoded JSON that looks like