From a3d1f16c85e706ff6a0367f8799010365f7f2625 Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Tue, 2 Dec 2025 11:14:11 +0000 Subject: [PATCH 1/2] mlflow: Add version tags for registered models Add the following model version tags when logging a model to MLflow: * model_uri: The URI of the model artifact * model_type: The type of the model (e.g. 'medcat_snomed') * validation_status: The validation status of the model (e.g. 'pending') Signed-off-by: Phoevos Kalemkeris --- app/cli/cli.py | 1 + app/management/tracker_client.py | 51 ++++++++++++++++++--- app/trainers/huggingface_llm_trainer.py | 1 + app/trainers/huggingface_ner_trainer.py | 2 + app/trainers/medcat_deid_trainer.py | 7 ++- app/trainers/medcat_trainer.py | 14 +++++- app/trainers/metacat_trainer.py | 7 ++- tests/app/monitoring/test_tracker_client.py | 36 +++++++++++++-- 8 files changed, 105 insertions(+), 14 deletions(-) diff --git a/app/cli/cli.py b/app/cli/cli.py index 4358338..612a8dc 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -304,6 +304,7 @@ def register_model( model_config=m_config, model_metrics=m_metrics, model_tags=m_tags, + model_type=model_type.value, ) typer.echo(f"Pushed {model_path} as a new model version ({run_name})") diff --git a/app/management/tracker_client.py b/app/management/tracker_client.py index 61686a2..0c43239 100644 --- a/app/management/tracker_client.py +++ b/app/management/tracker_client.py @@ -346,6 +346,39 @@ def log_model_config(config: Dict[str, str]) -> None: mlflow.log_params(config) + @staticmethod + def _set_model_version_tags( + client: MlflowClient, + model_name: str, + version: str, + model_type: Optional[str] = None, + validation_status: Optional[str] = None, + ) -> None: + """ + Sets standard tags on a model version for serving and discovery. + + Args: + client (MlflowClient): The MLflow client to use for setting tags. + model_name (str): The name of the registered model. + version (str): The version of the model. + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). + validation_status (Optional[str]): The status of the model validation (e.g., "pending"). + """ + try: + client.set_model_version_tag( + name=model_name, version=version, key="model_uri", value=f"models:/{model_name}/{version}" + ) + if model_type is not None: + client.set_model_version_tag( + name=model_name, version=version, key="model_type", value=model_type + ) + if validation_status is not None: + client.set_model_version_tag( + name=model_name, version=version, key="validation_status", value=validation_status + ) + except Exception: + logger.warning("Failed to set tags on version %s of model %s", version, model_name) + @staticmethod def log_model( model_name: str, @@ -386,6 +419,7 @@ def save_pretrained_model( model_config: Optional[Dict] = None, model_metrics: Optional[List[Dict]] = None, model_tags: Optional[Dict] = None, + model_type: Optional[str] = None, ) -> None: """ Saves a pretrained model to the tracking backend and associated metadata. @@ -399,6 +433,7 @@ def save_pretrained_model( model_config (Optional[Dict]): The configuration of the model to save. model_metrics (Optional[List[Dict]]): The list of dictionaries containing model metrics to save. model_tags (Optional[Dict]): The dictionary of tags to set for the model. + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). """ experiment_name = TrackerClient.get_experiment_name(model_name, training_type) @@ -423,6 +458,10 @@ def save_pretrained_model( mlflow.set_tags(tags) model_name = model_name.replace(" ", "_") TrackerClient.log_model(model_name, model_path, model_manager, model_name) + client = MlflowClient() + versions = client.search_model_versions(f"name='{model_name}'") + if versions: + TrackerClient._set_model_version_tags(client, model_name, versions[0].version, model_type) TrackerClient.end_with_success() except KeyboardInterrupt: TrackerClient.end_with_interruption() @@ -503,6 +542,7 @@ def save_model( model_name: str, model_manager: ModelManager, validation_status: str = "pending", + model_type: Optional[str] = None, ) -> str: """ Saves a model and its information to the tracking backend. @@ -512,6 +552,7 @@ def save_model( model_name (str): The name of the model. model_manager (ModelManager): The instance of ModelManager used for model saving. validation_status (str): The status of the model validation (default: "pending"). + model_type (Optional[str]): The type of the model (e.g., "medcat_snomed"). Returns: str: The artifact URI of the saved model. @@ -524,12 +565,10 @@ def save_model( if not mlflow.get_tracking_uri().startswith("file:/"): TrackerClient.log_model(model_name, filepath, model_manager, model_name) versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") - self.mlflow_client.set_model_version_tag( - name=model_name, - version=versions[0].version, - key="validation_status", - value=validation_status, - ) + if versions: + TrackerClient._set_model_version_tags( + self.mlflow_client, model_name, versions[0].version, model_type, validation_status + ) else: TrackerClient.log_model(model_name, filepath, model_manager) diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index b85f44b..0afd2a2 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -436,6 +436,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/huggingface_ner_trainer.py b/app/trainers/huggingface_ner_trainer.py index c975506..5e485ad 100644 --- a/app/trainers/huggingface_ner_trainer.py +++ b/app/trainers/huggingface_ner_trainer.py @@ -237,6 +237,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: @@ -664,6 +665,7 @@ def run( retrained_model_pack_path, self._model_name, self._model_manager, + model_type=self._model_service.info().model_type.value, ) logger.info(f"Retrained model saved: {model_uri}") else: diff --git a/app/trainers/medcat_deid_trainer.py b/app/trainers/medcat_deid_trainer.py index 65ac7be..5ee6e55 100644 --- a/app/trainers/medcat_deid_trainer.py +++ b/app/trainers/medcat_deid_trainer.py @@ -185,7 +185,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/medcat_trainer.py b/app/trainers/medcat_trainer.py index e49068f..ec1ff28 100644 --- a/app/trainers/medcat_trainer.py +++ b/app/trainers/medcat_trainer.py @@ -211,7 +211,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: @@ -472,7 +477,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info(f"Retrained model saved: {model_uri}") self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/app/trainers/metacat_trainer.py b/app/trainers/metacat_trainer.py index 5cce3b9..60992a2 100644 --- a/app/trainers/metacat_trainer.py +++ b/app/trainers/metacat_trainer.py @@ -159,7 +159,12 @@ def run( ) with open(cdb_config_path, "w") as f: json.dump(dump_pydantic_object_to_dict(model.config), f) - model_uri = self._tracker_client.save_model(model_pack_path, self._model_name, self._model_manager) + model_uri = self._tracker_client.save_model( + model_pack_path, + self._model_name, + self._model_manager, + model_type=self._model_service.info().model_type.value, + ) logger.info("Retrained model saved: %s", model_uri) self._tracker_client.save_model_artifact(cdb_config_path, self._model_name) else: diff --git a/tests/app/monitoring/test_tracker_client.py b/tests/app/monitoring/test_tracker_client.py index edcbaee..ed06a2a 100644 --- a/tests/app/monitoring/test_tracker_client.py +++ b/tests/app/monitoring/test_tracker_client.py @@ -3,7 +3,7 @@ import datasets import pytest import pandas as pd -from unittest.mock import Mock, call, ANY +from unittest.mock import Mock, call, patch, ANY from app.management.tracker_client import TrackerClient from app.data import doc_dataset from app.domain import TrainerBackend @@ -161,11 +161,23 @@ def test_save_model(mlflow_fixture): mlflow_client.search_model_versions.return_value = [version] tracker_client.mlflow_client = mlflow_client - artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status") + artifact_uri = tracker_client.save_model( + "path/to/file.zip", "model_name", model_manager, "validation_status", "model_type" + ) assert "artifacts/model_name" in artifact_uri model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name") - mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status") + mlflow_client.search_model_versions.assert_called_once_with("name='model_name'") + assert mlflow_client.set_model_version_tag.call_count == 3 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="validation_status", value="validation_status" + ) mlflow.set_tag.has_calls( [ call("training.output.package", "file.zip"), @@ -184,9 +196,15 @@ def test_save_model_local(mlflow_fixture): model_manager.save_model.assert_called_once_with("local_dir", "filepath") -def test_save_pretrained_model(mlflow_fixture): +@patch("app.management.tracker_client.MlflowClient") +def test_save_pretrained_model(mock_mlflow_client_class, mlflow_fixture): tracker_client = TrackerClient("") model_manager = Mock() + mlflow_client = Mock() + version = Mock() + version.version = "1" + mlflow_client.search_model_versions.return_value = [version] + mock_mlflow_client_class.return_value = mlflow_client tracker_client.save_pretrained_model( "model_name", @@ -197,6 +215,7 @@ def test_save_pretrained_model(mlflow_fixture): {"param": "value"}, [{"p": 0.8, "r": 0.8}, {"p": 0.9, "r": 0.9}], {"tag_name": "tag_value"}, + "model_type", ) mlflow.get_experiment_by_name.assert_called_once_with("model_name_training_type") @@ -212,6 +231,15 @@ def test_save_pretrained_model(mlflow_fixture): assert len(mlflow.set_tags.call_args.args[0]["mlflow.source.name"]) > 0 assert mlflow.set_tags.call_args.args[0]["tag_name"] == "tag_value" + mlflow_client.search_model_versions.assert_called_once_with("name='model_name'") + assert mlflow_client.set_model_version_tag.call_count == 2 + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_uri", value="models:/model_name/1" + ) + mlflow_client.set_model_version_tag.assert_any_call( + name="model_name", version="1", key="model_type", value="model_type" + ) + def test_log_single_exception(mlflow_fixture): tracker_client = TrackerClient("") From 6aeaecb514079ac7c952852271318ae72d9c8408 Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Tue, 2 Dec 2025 11:36:41 +0000 Subject: [PATCH 2/2] mlflow: Add model type to training run tags Signed-off-by: Phoevos Kalemkeris --- app/management/tracker_client.py | 1 + tests/app/monitoring/test_tracker_client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/app/management/tracker_client.py b/app/management/tracker_client.py index 0c43239..5ab7dfc 100644 --- a/app/management/tracker_client.py +++ b/app/management/tracker_client.py @@ -574,6 +574,7 @@ def save_model( artifact_uri = mlflow.get_artifact_uri(model_name) mlflow.set_tag("training.output.model_uri", artifact_uri) + mlflow.set_tag("training.output.model_type", model_type) return artifact_uri diff --git a/tests/app/monitoring/test_tracker_client.py b/tests/app/monitoring/test_tracker_client.py index ed06a2a..46d4697 100644 --- a/tests/app/monitoring/test_tracker_client.py +++ b/tests/app/monitoring/test_tracker_client.py @@ -182,6 +182,7 @@ def test_save_model(mlflow_fixture): [ call("training.output.package", "file.zip"), call("training.output.model_uri", artifact_uri), + call("training.output.model_type", "model_type"), ], any_order=False, )