diff --git a/app/cli/cli.py b/app/cli/cli.py index 4358338..612a8dc 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -304,6 +304,7 @@ def register_model( model_config=m_config, model_metrics=m_metrics, model_tags=m_tags, + model_type=model_type.value, ) typer.echo(f"Pushed {model_path} as a new model version ({run_name})") diff --git a/app/management/tracker_client.py b/app/management/tracker_client.py index 61686a2..5ab7dfc 100644 --- a/app/management/tracker_client.py +++ b/app/management/tracker_client.py @@ -346,6 +346,39 @@ def log_model_config(config: Dict[str, str]) -> None: mlflow.log_params(config) + @staticmethod + def _set_model_version_tags( + client: MlflowClient, + model_name: str, + version: str, + model_type: Optional[str] = None, + validation_status: Optional[str] = None, + ) -> None: + """ + Sets standard tags on a model version for serving and discovery. + + Args: + client (MlflowClient): The MLflow client to use for setting tags. + model_name (str): The name of the registered model. + version (str): The version of the model. + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). + validation_status (Optional[str]): The status of the model validation (e.g., "pending"). + """ + try: + client.set_model_version_tag( + name=model_name, version=version, key="model_uri", value=f"models:/{model_name}/{version}" + ) + if model_type is not None: + client.set_model_version_tag( + name=model_name, version=version, key="model_type", value=model_type + ) + if validation_status is not None: + client.set_model_version_tag( + name=model_name, version=version, key="validation_status", value=validation_status + ) + except Exception: + logger.warning("Failed to set tags on version %s of model %s", version, model_name) + @staticmethod def log_model( model_name: str, @@ -386,6 +419,7 @@ def save_pretrained_model( model_config: Optional[Dict] = None, model_metrics: Optional[List[Dict]] = None, model_tags: Optional[Dict] = None, + model_type: Optional[str] = None, ) -> None: """ Saves a pretrained model to the tracking backend and associated metadata. @@ -399,6 +433,7 @@ def save_pretrained_model( model_config (Optional[Dict]): The configuration of the model to save. model_metrics (Optional[List[Dict]]): The list of dictionaries containing model metrics to save. model_tags (Optional[Dict]): The dictionary of tags to set for the model. + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). """ experiment_name = TrackerClient.get_experiment_name(model_name, training_type) @@ -423,6 +458,10 @@ def save_pretrained_model( mlflow.set_tags(tags) model_name = model_name.replace(" ", "_") TrackerClient.log_model(model_name, model_path, model_manager, model_name) + client = MlflowClient() + versions = client.search_model_versions(f"name='{model_name}'") + if versions: + TrackerClient._set_model_version_tags(client, model_name, versions[0].version, model_type) TrackerClient.end_with_success() except KeyboardInterrupt: TrackerClient.end_with_interruption() @@ -503,6 +542,7 @@ def save_model( model_name: str, model_manager: ModelManager, validation_status: str = "pending", + model_type: Optional[str] = None, ) -> str: """ Saves a model and its information to the tracking backend. @@ -512,6 +552,7 @@ def save_model( model_name (str): The name of the model. model_manager (ModelManager): The instance of ModelManager used for model saving. validation_status (str): The status of the model validation (default: "pending"). + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). Returns: str: The artifact URI of the saved model. @@ -524,17 +565,16 @@ def save_model( if not mlflow.get_tracking_uri().startswith("file:/"): TrackerClient.log_model(model_name, filepath, model_manager, model_name) versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") - self.mlflow_client.set_model_version_tag( - name=model_name, - version=versions[0].version, - key="validation_status", - value=validation_status, - ) + if versions: + TrackerClient._set_model_version_tags( + self.mlflow_client, model_name, versions[0].version, model_type, validation_status + ) else: TrackerClient.log_model(model_name, filepath, model_manager) artifact_uri = mlflow.get_artifact_uri(model_name) mlflow.set_tag("training.output.model_uri", artifact_uri) + mlflow.set_tag("training.output.model_type", model_type) return artifact_uri diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index b85f44b..0afd2a2 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -436,6 +436,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/huggingface_ner_trainer.py b/app/trainers/huggingface_ner_trainer.py index c975506..5e485ad 100644 --- a/app/trainers/huggingface_ner_trainer.py +++ b/app/trainers/huggingface_ner_trainer.py @@ -237,6 +237,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: @@ -664,6 +665,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/medcat_deid_trainer.py b/app/trainers/medcat_deid_trainer.py index 65ac7be..5ee6e55 100644 --- a/app/trainers/medcat_deid_trainer.py +++ b/app/trainers/medcat_deid_trainer.py @@ -185,7 +185,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/medcat_trainer.py b/app/trainers/medcat_trainer.py index e49068f..ec1ff28 100644 --- a/app/trainers/medcat_trainer.py +++ b/app/trainers/medcat_trainer.py @@ -211,7 +211,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: @@ -472,7 +477,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info(f"Retrained model saved: {model_uri}") self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/metacat_trainer.py b/app/trainers/metacat_trainer.py index 5cce3b9..60992a2 100644 --- a/app/trainers/metacat_trainer.py +++ b/app/trainers/metacat_trainer.py @@ -159,7 +159,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/tests/app/monitoring/test_tracker_client.py b/tests/app/monitoring/test_tracker_client.py index edcbaee..46d4697 100644 --- a/tests/app/monitoring/test_tracker_client.py +++ b/tests/app/monitoring/test_tracker_client.py @@ -3,7 +3,7 @@ import datasets import pytest import pandas as pd -from unittest.mock import Mock, call, ANY +from unittest.mock import Mock, call, patch, ANY from app.management.tracker_client import TrackerClient from app.data import doc_dataset from app.domain import TrainerBackend @@ -161,15 +161,28 @@ def test_save_model(mlflow_fixture): mlflow_client.search_model_versions.return_value = [version] tracker_client.mlflow_client = mlflow_client - artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status") + artifact_uri = tracker_client.save_model( + "path/to/file.zip", "model_name", model_manager, "validation_status", "model_type" + ) assert "artifacts/model_name" in artifact_uri model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name") - mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status") + mlflow_client.search_model_versions.assert_called_once_with("name='model_name'") + assert mlflow_client.set_model_version_tag.call_count == 3 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="validation_status", value="validation_status" + ) mlflow.set_tag.has_calls( [ call("training.output.package", "file.zip"), call("training.output.model_uri", artifact_uri), + call("training.output.model_type", "model_type"), ], any_order=False, ) @@ -184,9 +197,15 @@ def test_save_model_local(mlflow_fixture): model_manager.save_model.assert_called_once_with("local_dir", "filepath") -def test_save_pretrained_model(mlflow_fixture): +@patch("app.management.tracker_client.MlflowClient") +def test_save_pretrained_model(mock_mlflow_client_class, mlflow_fixture): tracker_client = TrackerClient("") model_manager = Mock() + mlflow_client = Mock() + version = Mock() + version.version = "1" + mlflow_client.search_model_versions.return_value = [version] + mock_mlflow_client_class.return_value = mlflow_client tracker_client.save_pretrained_model( "model_name", @@ -197,6 +216,7 @@ def test_save_pretrained_model(mlflow_fixture): {"param": "value"}, [{"p": 0.8, "r": 0.8}, {"p": 0.9, "r": 0.9}], {"tag_name": "tag_value"}, + "model_type", ) mlflow.get_experiment_by_name.assert_called_once_with("model_name_training_type") @@ -212,6 +232,15 @@ def test_save_pretrained_model(mlflow_fixture): assert len(mlflow.set_tags.call_args.args[0]["mlflow.source.name"]) > 0 assert mlflow.set_tags.call_args.args[0]["tag_name"] == "tag_value" + mlflow_client.search_model_versions.assert_called_once_with("name='model_name'") + assert mlflow_client.set_model_version_tag.call_count == 2 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + def test_log_single_exception(mlflow_fixture): tracker_client = TrackerClient("")