Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions undeepvo/utils/mflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,26 @@
import mlflow
import mlflow.exceptions

DEFAULT_USER_NAME = ""
DEFAULT_PASSWORD = ""
DEFAULT_DATABRICKS_HOST = ""
DEFAULT_HOST_URI = "http://329801-ilinvalery.tmweb.ru:5001/"
DEFAULT_EXPERIMENT_NAME = "undeepvo"
CREATE_DATABRICKS_CREDENTIALS = False

os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://329801-ilinvalery.tmweb.ru:9000"
os.environ["AWS_ACCESS_KEY_ID"] = "123"
os.environ["AWS_SECRET_ACCESS_KEY"] = "12345678"


class MlFlowHandler(object):
def __init__(self, experiment_name=DEFAULT_EXPERIMENT_NAME, user_name=DEFAULT_USER_NAME, password=DEFAULT_PASSWORD,
host_uri=DEFAULT_HOST_URI, create_databricks_credential=CREATE_DATABRICKS_CREDENTIALS,
databricks_host=DEFAULT_DATABRICKS_HOST, mlflow_tags={}, mlflow_parameters={}):
self._user_name = DEFAULT_USER_NAME
self._password = DEFAULT_PASSWORD
if host_uri == "databricks" and create_databricks_credential:
self._create_databricks_credential(user_name, password, databricks_host)
def __init__(self,
experiment_name="",
host_uri="",
databricks_config=None,
artifact_aws_config=None,
mlflow_tags=None,
mlflow_parameters=None):
if databricks_config is not None:
self._create_databricks_credential(databricks_config["username"], databricks_config["password"],
databricks_config["databricks_host"])
if artifact_aws_config is not None:
os.environ["MLFLOW_S3_ENDPOINT_URL"] = artifact_aws_config["endpoint"]
os.environ["AWS_ACCESS_KEY_ID"] = artifact_aws_config["username"]
os.environ["AWS_SECRET_ACCESS_KEY"] = artifact_aws_config["password"]
if mlflow_tags is None:
mlflow_tags = {}
if mlflow_parameters is None:
mlflow_parameters = {}
mlflow.set_tracking_uri(host_uri)
self._experiment_name = experiment_name
self._mlflow_client = mlflow.tracking.MlflowClient(host_uri)
Expand Down
15 changes: 2 additions & 13 deletions undeepvo/utils/training_process_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,10 @@
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from .mflow_handler import MlFlowHandler


class TrainingProcessHandler(object):
def __init__(self, data_folder="logs", model_folder="model", enable_iteration_progress_bar=False,
model_save_key="loss", mlflow_tags=None, mlflow_parameters=None, enable_mlflow=True,
mlflow_experiment_name="undeepvo"):
if mlflow_tags is None:
mlflow_tags = {}
if mlflow_parameters is None:
mlflow_parameters = {}
model_save_key="loss", mlflow_handler=None):
self._name = None
self._epoch_count = 0
self._iteration_count = 0
Expand All @@ -41,11 +34,7 @@ def __init__(self, data_folder="logs", model_folder="model", enable_iteration_pr
self._audio_configs = {}
self._global_epoch_step = 0
self._global_iteration_step = 0
if enable_mlflow:
self._mlflow_handler = MlFlowHandler(experiment_name=mlflow_experiment_name,
mlflow_tags=mlflow_tags, mlflow_parameters=mlflow_parameters)
else:
self._mlflow_handler = None
self._mlflow_handler = mlflow_handler
self._artifacts = []

def setup_handler(self, name, model):
Expand Down