diff --git a/undeepvo/utils/mflow_handler.py b/undeepvo/utils/mflow_handler.py index d1082e0..5c79b16 100644 --- a/undeepvo/utils/mflow_handler.py +++ b/undeepvo/utils/mflow_handler.py @@ -3,26 +3,26 @@ import mlflow import mlflow.exceptions -DEFAULT_USER_NAME = "" -DEFAULT_PASSWORD = "" -DEFAULT_DATABRICKS_HOST = "" -DEFAULT_HOST_URI = "http://329801-ilinvalery.tmweb.ru:5001/" -DEFAULT_EXPERIMENT_NAME = "undeepvo" -CREATE_DATABRICKS_CREDENTIALS = False - -os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://329801-ilinvalery.tmweb.ru:9000" -os.environ["AWS_ACCESS_KEY_ID"] = "123" -os.environ["AWS_SECRET_ACCESS_KEY"] = "12345678" - class MlFlowHandler(object): - def __init__(self, experiment_name=DEFAULT_EXPERIMENT_NAME, user_name=DEFAULT_USER_NAME, password=DEFAULT_PASSWORD, - host_uri=DEFAULT_HOST_URI, create_databricks_credential=CREATE_DATABRICKS_CREDENTIALS, - databricks_host=DEFAULT_DATABRICKS_HOST, mlflow_tags={}, mlflow_parameters={}): - self._user_name = DEFAULT_USER_NAME - self._password = DEFAULT_PASSWORD - if host_uri == "databricks" and create_databricks_credential: - self._create_databricks_credential(user_name, password, databricks_host) + def __init__(self, + experiment_name="", + host_uri="", + databricks_config=None, + artifact_aws_config=None, + mlflow_tags=None, + mlflow_parameters=None): + if databricks_config is not None: + self._create_databricks_credential(databricks_config["username"], databricks_config["password"], + databricks_config["databricks_host"]) + if artifact_aws_config is not None: + os.environ["MLFLOW_S3_ENDPOINT_URL"] = artifact_aws_config["endpoint"] + os.environ["AWS_ACCESS_KEY_ID"] = artifact_aws_config["username"] + os.environ["AWS_SECRET_ACCESS_KEY"] = artifact_aws_config["password"] + if mlflow_tags is None: + mlflow_tags = {} + if mlflow_parameters is None: + mlflow_parameters = {} mlflow.set_tracking_uri(host_uri) self._experiment_name = experiment_name self._mlflow_client = mlflow.tracking.MlflowClient(host_uri) diff --git a/undeepvo/utils/training_process_handler.py b/undeepvo/utils/training_process_handler.py index 56dcda0..3298e23 100644 --- a/undeepvo/utils/training_process_handler.py +++ b/undeepvo/utils/training_process_handler.py @@ -5,17 +5,10 @@ from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm -from .mflow_handler import MlFlowHandler - class TrainingProcessHandler(object): def __init__(self, data_folder="logs", model_folder="model", enable_iteration_progress_bar=False, - model_save_key="loss", mlflow_tags=None, mlflow_parameters=None, enable_mlflow=True, - mlflow_experiment_name="undeepvo"): - if mlflow_tags is None: - mlflow_tags = {} - if mlflow_parameters is None: - mlflow_parameters = {} + model_save_key="loss", mlflow_handler=None): self._name = None self._epoch_count = 0 self._iteration_count = 0 @@ -41,11 +34,7 @@ def __init__(self, data_folder="logs", model_folder="model", enable_iteration_pr self._audio_configs = {} self._global_epoch_step = 0 self._global_iteration_step = 0 - if enable_mlflow: - self._mlflow_handler = MlFlowHandler(experiment_name=mlflow_experiment_name, - mlflow_tags=mlflow_tags, mlflow_parameters=mlflow_parameters) - else: - self._mlflow_handler = None + self._mlflow_handler = mlflow_handler self._artifacts = [] def setup_handler(self, name, model):