diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh
index 42dcb4a9f..c427ee475 100644
--- a/cli/cli_tests_training.sh
+++ b/cli/cli_tests_training.sh
@@ -163,16 +163,6 @@ echo "AGG_UID=$AGG_UID" >> "$LAST_ENV_FILE"
echo "\n"
-##########################################################
-echo "====================================="
-echo "Running aggregator association step"
-echo "====================================="
-print_eval medperf aggregator associate -a $AGG_UID -t $TRAINING_UID -y
-checkFailed "aggregator association step failed"
-##########################################################
-
-echo "\n"
-
##########################################################
echo "====================================="
echo "Activate modelowner profile"
@@ -185,10 +175,10 @@ echo "\n"
##########################################################
echo "====================================="
-echo "Approve aggregator association"
+echo "Running set aggregator step"
echo "====================================="
-print_eval medperf association approve -t $TRAINING_UID -a $AGG_UID
-checkFailed "agg association approval failed"
+print_eval medperf training set_aggregator -t $TRAINING_UID -a $AGG_UID -y
+checkFailed "Setting aggregator failed"
##########################################################
echo "\n"
diff --git a/cli/medperf/_version.py b/cli/medperf/_version.py
index d3ec452c3..493f7415d 100644
--- a/cli/medperf/_version.py
+++ b/cli/medperf/_version.py
@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"
diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py
index 749c1e200..5132cedc8 100644
--- a/cli/medperf/asset_management/asset_management.py
+++ b/cli/medperf/asset_management/asset_management.py
@@ -9,7 +9,7 @@
from medperf.asset_management.asset_storage_manager import AssetStorageManager
from medperf.asset_management.asset_policy_manager import AssetPolicyManager
from medperf.asset_management.cc_operator import OperatorManager
-from medperf.utils import tar, generate_tmp_path
+from medperf.utils import tar, generate_tmp_path, remove_path
import secrets
from medperf.exceptions import MedperfException
from medperf import config as medperf_config
@@ -42,12 +42,20 @@ def setup_dataset_for_cc(dataset: Dataset):
cc_policy = dataset.get_cc_policy()
__verify_cloud_environment(cc_config)
- # create dataset asset
+ # policy setup
+ medperf_config.ui.text = "Generating encryption key"
+ encryption_key = generate_encryption_key()
+ asset_policy_manager = AssetPolicyManager(cc_config)
+ asset_policy_manager.setup_policy(cc_policy, encryption_key)
+
+ # storage
medperf_config.ui.text = "Compressing dataset"
asset_path = generate_tmp_path()
tar(asset_path, [dataset.data_path, dataset.labels_path])
-
- __setup_asset_for_cc(cc_config, cc_policy, asset_path)
+ asset_storage_manager = AssetStorageManager(cc_config, asset_path, encryption_key)
+ asset_storage_manager.store_asset()
+ del encryption_key
+ remove_path(asset_path)
def setup_model_for_cc(model: Model):
@@ -55,42 +63,31 @@ def setup_model_for_cc(model: Model):
return
cc_config = model.get_cc_config()
cc_policy = model.get_cc_policy()
- if model.type != "ASSET":
+ if not model.is_asset():
raise MedperfException(
f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing."
)
asset = model.asset_obj
- # create model asset
asset_path = asset.get_archive_path()
__verify_cloud_environment(cc_config)
- __setup_asset_for_cc(cc_config, cc_policy, asset_path, for_model=True)
-
-
-def __verify_cloud_environment(cc_config: dict):
- AssetStorageManager(cc_config, None, None).setup()
-
-def __setup_asset_for_cc(
- cc_config: dict,
- cc_policy: dict,
- asset_path: str,
- for_model: bool = False,
-):
- # create encryption key
+ # policy setup
+ medperf_config.ui.text = "Generating encryption key"
encryption_key = generate_encryption_key()
-
- asset_storage_manager = AssetStorageManager(cc_config, asset_path, encryption_key)
- asset_policy_manager = AssetPolicyManager(cc_config, for_model=for_model)
+ asset_policy_manager = AssetPolicyManager(cc_config, for_model=True)
+ asset_policy_manager.setup_policy(cc_policy, encryption_key)
# storage
+ asset_storage_manager = AssetStorageManager(cc_config, asset_path, encryption_key)
asset_storage_manager.store_asset()
-
- # policy setup
- asset_policy_manager.setup_policy(cc_policy, encryption_key)
del encryption_key
+def __verify_cloud_environment(cc_config: dict):
+ AssetStorageManager(cc_config, None, None).setup()
+
+
def update_dataset_cc_policy(dataset: Dataset, permitted_workloads: list[CCWorkloadID]):
if not dataset.is_cc_configured():
raise MedperfException(
@@ -108,7 +105,7 @@ def update_model_cc_policy(model: Model, permitted_workloads: list[CCWorkloadID]
f"Model {model.id} does not have a configuration for confidential computing."
)
cc_config = model.get_cc_config()
- if model.type != "ASSET":
+ if not model.is_asset():
raise MedperfException(
f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing."
)
diff --git a/cli/medperf/asset_management/asset_storage_manager.py b/cli/medperf/asset_management/asset_storage_manager.py
index f772a3217..7363b9cb9 100644
--- a/cli/medperf/asset_management/asset_storage_manager.py
+++ b/cli/medperf/asset_management/asset_storage_manager.py
@@ -6,10 +6,15 @@
remove_path,
)
from medperf.encryption import SymmetricEncryption
-from medperf.asset_management.gcp_utils import GCPAssetConfig, upload_file_to_gcs
+from medperf.asset_management.gcp_utils import (
+ GCPAssetConfig,
+ upload_from_file_object_to_gcs,
+)
from medperf.asset_management.asset_check import verify_asset_owner_setup
+from medperf.asset_management.utils import CustomWriter, get_file_size
from medperf.exceptions import MedperfException
from medperf import config as medperf_config
+from tqdm import tqdm
class AssetStorageManager:
@@ -30,12 +35,25 @@ def __encrypt_asset(self):
asset_hash = get_file_hash(tmp_encrypted_asset_path)
return tmp_encrypted_asset_path, asset_hash
- def __upload_encrypted_asset(self, tmp_encrypted_asset_path):
- upload_file_to_gcs(
- self.config,
- tmp_encrypted_asset_path,
- self.config.encrypted_asset_bucket_file,
- )
+ def __upload_encrypted_asset(self, tmp_encrypted_asset_path: str):
+ with open(tmp_encrypted_asset_path, "rb") as in_file:
+ with tqdm.wrapattr(
+ in_file,
+ "read",
+ total=get_file_size(in_file),
+ miniters=1,
+                desc="Uploading encrypted asset to the bucket",
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ file=CustomWriter(),
+ ) as file_obj:
+ upload_from_file_object_to_gcs(
+ self.config,
+ file_obj,
+ self.config.encrypted_asset_bucket_file,
+ )
+ remove_path(tmp_encrypted_asset_path)
def setup(self):
medperf_config.ui.text = "Verifying Cloud Environment"
diff --git a/cli/medperf/asset_management/gcp_utils/__init__.py b/cli/medperf/asset_management/gcp_utils/__init__.py
index fcb456b39..b6f4a4bc0 100644
--- a/cli/medperf/asset_management/gcp_utils/__init__.py
+++ b/cli/medperf/asset_management/gcp_utils/__init__.py
@@ -2,6 +2,7 @@
from .kms import set_kms_iam_policy, encrypt_with_kms_key
from .storage import (
upload_file_to_gcs,
+ upload_from_file_object_to_gcs,
upload_string_to_gcs,
download_file_from_gcs,
download_string_from_gcs,
@@ -19,6 +20,7 @@
"set_kms_iam_policy",
"encrypt_with_kms_key",
"upload_file_to_gcs",
+ "upload_from_file_object_to_gcs",
"upload_string_to_gcs",
"download_file_from_gcs",
"download_string_from_gcs",
diff --git a/cli/medperf/asset_management/gcp_utils/storage.py b/cli/medperf/asset_management/gcp_utils/storage.py
index e971c3b46..f4b4e1de0 100644
--- a/cli/medperf/asset_management/gcp_utils/storage.py
+++ b/cli/medperf/asset_management/gcp_utils/storage.py
@@ -13,6 +13,16 @@ def upload_file_to_gcs(
blob.upload_from_filename(local_file)
+def upload_from_file_object_to_gcs(
+ config: Union[GCPAssetConfig, GCPOperatorConfig], file: object, gcs_path: str
+):
+    """Upload data from an open file object to Google Cloud Storage."""
+ client = storage.Client()
+ bucket = client.bucket(config.bucket)
+ blob = bucket.blob(gcs_path)
+ blob.upload_from_file(file)
+
+
def upload_string_to_gcs(
config: Union[GCPAssetConfig, GCPOperatorConfig], content: bytes, gcs_path: str
):
diff --git a/cli/medperf/asset_management/utils.py b/cli/medperf/asset_management/utils.py
new file mode 100644
index 000000000..9fb61cd97
--- /dev/null
+++ b/cli/medperf/asset_management/utils.py
@@ -0,0 +1,21 @@
+from medperf import config
+import os
+
+
+class CustomWriter:
+    """File-like writer that routes tqdm progress output through config.ui."""
+
+ def write(self, msg):
+ config.ui.print(msg)
+
+ def flush(self):
+ pass
+
+
+def get_file_size(file_object) -> int:
+    """Return the file size in bytes, or None if it cannot be determined."""
+ try:
+ total_bytes = os.fstat(file_object.fileno()).st_size
+ except (AttributeError, OSError):
+ total_bytes = None
+ return total_bytes
diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py
index b0cb93531..7c751151b 100644
--- a/cli/medperf/commands/aggregator/aggregator.py
+++ b/cli/medperf/commands/aggregator/aggregator.py
@@ -5,7 +5,6 @@
import medperf.config as config
from medperf.decorators import clean_except
from medperf.commands.aggregator.submit import SubmitAggregator
-from medperf.commands.aggregator.associate import AssociateAggregator
from medperf.commands.aggregator.run import StartAggregator
from medperf.commands.list import EntityList
@@ -33,22 +32,6 @@ def submit(
config.ui.print("✅ Done!")
-@app.command("associate")
-@clean_except
-def associate(
- aggregator_id: int = typer.Option(
- ..., "--aggregator_id", "-a", help="UID of benchmark to associate with"
- ),
- training_exp_id: int = typer.Option(
- ..., "--training_exp_id", "-t", help="UID of benchmark to associate with"
- ),
- approval: bool = typer.Option(False, "-y", help="Skip approval step"),
-):
- """Associates an aggregator with a training experiment."""
- AssociateAggregator.run(aggregator_id, training_exp_id, approved=approval)
- config.ui.print("✅ Done!")
-
-
@app.command("start")
@clean_except
def run(
diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py
index 76f4179d9..62c611430 100644
--- a/cli/medperf/commands/aggregator/run.py
+++ b/cli/medperf/commands/aggregator/run.py
@@ -24,14 +24,14 @@ def run(
training_exp_id (int): Training experiment UID.
"""
execution = cls(training_exp_id, publish_on, overwrite)
- execution.prepare()
- execution.validate()
- execution.check_existing_outputs()
- execution.prepare_aggregator()
- execution.prepare_participants_list()
- execution.prepare_plan()
- execution.prepare_pki_assets()
with config.ui.interactive():
+ execution.prepare()
+ execution.validate()
+ execution.check_existing_outputs()
+ execution.prepare_aggregator()
+ execution.prepare_participants_list()
+ execution.prepare_plan()
+ execution.prepare_pki_assets()
execution.run_experiment()
def __init__(self, training_exp_id, publish_on, overwrite) -> None:
diff --git a/cli/medperf/commands/aggregator/submit.py b/cli/medperf/commands/aggregator/submit.py
index 3fd653fcb..f93655122 100644
--- a/cli/medperf/commands/aggregator/submit.py
+++ b/cli/medperf/commands/aggregator/submit.py
@@ -23,6 +23,7 @@ def run(cls, name: str, address: str, port: int, aggregation_mlcube: int):
updated_benchmark_body = submission.submit()
ui.print("Uploaded")
submission.write(updated_benchmark_body)
+ return submission.aggregator.id
def __init__(self, name: str, address: str, port: int, aggregation_mlcube: int):
self.ui = config.ui
@@ -41,5 +42,5 @@ def submit(self):
def write(self, updated_body):
remove_path(self.aggregator.path)
- aggregator = Aggregator(**updated_body)
- aggregator.write()
+ self.aggregator = Aggregator(**updated_body)
+ self.aggregator.write()
diff --git a/cli/medperf/commands/association/approval.py b/cli/medperf/commands/association/approval.py
index 1c78f7dd0..ebbb4f6fa 100644
--- a/cli/medperf/commands/association/approval.py
+++ b/cli/medperf/commands/association/approval.py
@@ -10,15 +10,14 @@ def run(
training_exp_uid: int = None,
dataset_uid: int = None,
model_uid: int = None,
- aggregator_uid: int = None,
):
- """Sets approval status for an association between a benchmark and a dataset or mlcube
+ """Sets approval status for an association between a benchmark and a dataset or mlcube,
+ or between a training experiment and a dataset.
Args:
benchmark_uid (int): Benchmark UID.
approval_status (str): Desired approval status to set for the association.
- comms (Comms): Instance of Comms interface.
- ui (UI): Instance of UI interface.
+ training_exp_uid (int, optional): Training experiment UID. Defaults to None.
dataset_uid (int, optional): Dataset UID. Defaults to None.
mlcube_uid (int, optional): MLCube UID. Defaults to None.
"""
@@ -28,7 +27,6 @@ def run(
training_exp_uid,
dataset_uid,
model_uid,
- aggregator_uid,
approval_status.value,
)
update = {"approval_status": approval_status.value}
@@ -42,12 +40,7 @@ def run(
comms.update_benchmark_model_association(
benchmark_uid, model_uid, update
)
- if training_exp_uid:
- if dataset_uid:
- comms.update_training_dataset_association(
- training_exp_uid, dataset_uid, update
- )
- if aggregator_uid:
- comms.update_training_aggregator_association(
- training_exp_uid, aggregator_uid, update
- )
+ if training_exp_uid and dataset_uid:
+ comms.update_training_dataset_association(
+ training_exp_uid, dataset_uid, update
+ )
diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py
index a6f52346a..4bdfccc08 100644
--- a/cli/medperf/commands/association/association.py
+++ b/cli/medperf/commands/association/association.py
@@ -17,7 +17,6 @@ def list(
training_exp: bool = typer.Option(False, "-t", help="list training associations"),
dataset: bool = typer.Option(False, "-d", help="list dataset associations"),
model: bool = typer.Option(False, "-m", help="list models associations"),
- aggregator: bool = typer.Option(False, "-a", help="list aggregator associations"),
approval_status: str = typer.Option(
None, "--approval-status", help="Approval status"
),
@@ -28,9 +27,7 @@ def list(
filter (str, optional): Filter associations by approval status.
Defaults to displaying all user associations.
"""
- ListAssociations.run(
- approval_status, benchmark, training_exp, dataset, model, aggregator
- )
+ ListAssociations.run(approval_status, benchmark, training_exp, dataset, model)
@app.command("approve")
@@ -42,9 +39,6 @@ def approve(
),
dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"),
model_uid: int = typer.Option(None, "--model", "-m", help="Model container UID"),
- aggregator_uid: int = typer.Option(
- None, "--aggregator", "-a", help="Aggregator UID"
- ),
):
"""Approves an association between a benchmark and a dataset or model container
@@ -59,7 +53,6 @@ def approve(
training_exp_uid,
dataset_uid,
model_uid,
- aggregator_uid,
)
config.ui.print("✅ Done!")
@@ -73,9 +66,6 @@ def reject(
),
dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"),
model_uid: int = typer.Option(None, "--model", "-m", help="Model container UID"),
- aggregator_uid: int = typer.Option(
- None, "--aggregator", "-a", help="Aggregator UID"
- ),
):
"""Rejects an association between a benchmark and a dataset or model container
@@ -90,7 +80,6 @@ def reject(
training_exp_uid,
dataset_uid,
model_uid,
- aggregator_uid,
)
config.ui.print("✅ Done!")
diff --git a/cli/medperf/commands/association/list.py b/cli/medperf/commands/association/list.py
index cce94d892..82c62ca87 100644
--- a/cli/medperf/commands/association/list.py
+++ b/cli/medperf/commands/association/list.py
@@ -12,12 +12,9 @@ def run(
training_exp=False,
dataset=False,
model=False,
- aggregator=False,
):
"""Get user association requests"""
- validate_args(
- benchmark, training_exp, dataset, model, aggregator, approval_status
- )
+ validate_args(benchmark, training_exp, dataset, model, approval_status)
if training_exp:
experiment_type = "training_exp"
elif benchmark:
@@ -27,8 +24,6 @@ def run(
component_type = "model"
elif dataset:
component_type = "dataset"
- elif aggregator:
- component_type = "aggregator"
assocs = get_user_associations(experiment_type, component_type, approval_status)
diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py
index 80a5ffbea..bbcf20a3c 100644
--- a/cli/medperf/commands/association/utils.py
+++ b/cli/medperf/commands/association/utils.py
@@ -3,12 +3,11 @@
from pydantic.datetime_parse import parse_datetime
-def validate_args(benchmark, training_exp, dataset, model, aggregator, approval_status):
+def validate_args(benchmark, training_exp, dataset, model, approval_status):
training_exp = bool(training_exp)
benchmark = bool(benchmark)
dataset = bool(dataset)
model = bool(model)
- aggregator = bool(aggregator)
if approval_status is not None:
if approval_status.lower() not in ["pending", "approved", "rejected"]:
@@ -19,20 +18,13 @@ def validate_args(benchmark, training_exp, dataset, model, aggregator, approval_
raise InvalidArgumentError(
"One training experiment or a benchmark flag must be provided"
)
- if sum([dataset, model, aggregator]) != 1:
- raise InvalidArgumentError(
- "One dataset, model, or aggregator flag must be provided"
- )
+ if sum([dataset, model]) != 1:
+ raise InvalidArgumentError("One dataset or model flag must be provided")
if training_exp and model:
raise InvalidArgumentError(
"Invalid combination of arguments. There are no associations"
" between training experiments and models"
)
- if benchmark and aggregator:
- raise InvalidArgumentError(
- "Invalid combination of arguments. There are no associations"
- " between benchmarks and aggregators"
- )
def filter_latest_associations(associations, experiment_key, component_key):
@@ -61,17 +53,6 @@ def filter_latest_associations(associations, experiment_key, component_key):
return latest_associations
-def get_last_component(associations, experiment_key):
- associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"]))
- experiments_component = {}
- for assoc in associations:
- experiment_id = assoc[experiment_key]
- experiments_component[experiment_id] = assoc
-
- experiments_component = list(experiments_component.values())
- return experiments_component
-
-
def get_experiment_associations(
experiment_id: int,
experiment_type: str,
@@ -112,7 +93,6 @@ def get_user_associations(
comms_functions = {
"training_exp": {
"dataset": config.comms.get_user_training_datasets_associations,
- "aggregator": config.comms.get_user_training_aggregators_associations,
},
"benchmark": {
"dataset": config.comms.get_user_benchmarks_datasets_associations,
@@ -140,10 +120,6 @@ def _post_process_associtations(
):
assocs = filter_latest_associations(associations, experiment_type, component_type)
- if component_type == "aggregator":
- # an experiment should only have one aggregator
- assocs = get_last_component(assocs, experiment_type)
-
if approval_status:
approval_status = approval_status.upper()
assocs = [
diff --git a/cli/medperf/commands/cc/dataset_configure_for_cc.py b/cli/medperf/commands/cc/dataset_configure_for_cc.py
index 49dd023b2..c2662d75e 100644
--- a/cli/medperf/commands/cc/dataset_configure_for_cc.py
+++ b/cli/medperf/commands/cc/dataset_configure_for_cc.py
@@ -5,6 +5,7 @@
)
import json
from medperf import config
+from medperf.exceptions import InvalidEntityError
class DatasetConfigureForCC:
@@ -19,10 +20,18 @@ def run_from_files(cls, data_uid: int, cc_config_file: str, cc_policy_file: str)
@classmethod
def run(cls, data_uid: int, cc_config: dict, cc_policy: dict):
validate_cc_config(cc_config, "dataset" + str(data_uid))
+ dataset = Dataset.get(data_uid)
+ dataset.set_cc_config(cc_config)
+ dataset.set_cc_policy(cc_policy)
+ body = {"user_metadata": dataset.user_metadata}
+ config.comms.update_dataset(dataset.id, body)
with config.ui.interactive():
- dataset = Dataset.get(data_uid)
- dataset.set_cc_config(cc_config)
- dataset.set_cc_policy(cc_policy)
+ config.ui.text = "Checking dataset hash"
+ if not dataset.check_hash():
+ raise InvalidEntityError(
+ "Dataset hash does not match the one stored in the system."
+ )
setup_dataset_for_cc(dataset)
+ dataset.set_cc_initialized()
body = {"user_metadata": dataset.user_metadata}
config.comms.update_dataset(dataset.id, body)
diff --git a/cli/medperf/commands/cc/dataset_update_cc_policy.py b/cli/medperf/commands/cc/dataset_update_cc_policy.py
index 427323869..8799b4832 100644
--- a/cli/medperf/commands/cc/dataset_update_cc_policy.py
+++ b/cli/medperf/commands/cc/dataset_update_cc_policy.py
@@ -69,6 +69,10 @@ def run(cls, data_uid: int):
raise MedperfException(
f"Dataset {dataset.id} is not configured for confidential computing."
)
- permitted_workloads = get_permitted_workloads(dataset)
-
- update_dataset_cc_policy(dataset, permitted_workloads)
+ with config.ui.interactive():
+ config.ui.text = "Updating dataset confidential computing policy"
+ permitted_workloads = get_permitted_workloads(dataset)
+ update_dataset_cc_policy(dataset, permitted_workloads)
+ dataset.set_last_synced()
+ body = {"user_metadata": dataset.user_metadata}
+ config.comms.update_dataset(dataset.id, body)
diff --git a/cli/medperf/commands/cc/model_configure_for_cc.py b/cli/medperf/commands/cc/model_configure_for_cc.py
index be950931f..34bbc809c 100644
--- a/cli/medperf/commands/cc/model_configure_for_cc.py
+++ b/cli/medperf/commands/cc/model_configure_for_cc.py
@@ -5,6 +5,7 @@
)
import json
from medperf import config
+from medperf.exceptions import InvalidEntityError
class ModelConfigureForCC:
@@ -19,10 +20,18 @@ def run_from_files(cls, model_uid: int, cc_config_file: str, cc_policy_file: str
@classmethod
def run(cls, model_uid: int, cc_config: dict, cc_policy: dict):
validate_cc_config(cc_config, "model" + str(model_uid))
+ model = Model.get(model_uid)
+ model.set_cc_config(cc_config)
+ model.set_cc_policy(cc_policy)
+ body = {"user_metadata": model.user_metadata}
+ config.comms.update_model(model.id, body)
with config.ui.interactive():
- model = Model.get(model_uid)
- model.set_cc_config(cc_config)
- model.set_cc_policy(cc_policy)
+ config.ui.text = "Checking model hash"
+ if not model.check_hash():
+ raise InvalidEntityError(
+ "Model hash does not match the one stored in the system."
+ )
setup_model_for_cc(model)
+ model.set_cc_initialized()
body = {"user_metadata": model.user_metadata}
config.comms.update_model(model.id, body)
diff --git a/cli/medperf/commands/cc/model_update_cc_policy.py b/cli/medperf/commands/cc/model_update_cc_policy.py
index 8c464c240..8fbd02218 100644
--- a/cli/medperf/commands/cc/model_update_cc_policy.py
+++ b/cli/medperf/commands/cc/model_update_cc_policy.py
@@ -99,5 +99,10 @@ def run(cls, model_uid: int):
raise MedperfException(
f"Model {model.id} is not configured for confidential computing."
)
- permitted_workloads = get_permitted_workloads_without_datasets(model)
- update_model_cc_policy(model, permitted_workloads)
+ with config.ui.interactive():
+ config.ui.text = "Updating model confidential computing policy"
+ permitted_workloads = get_permitted_workloads_without_datasets(model)
+ update_model_cc_policy(model, permitted_workloads)
+ model.set_last_synced()
+ body = {"user_metadata": model.user_metadata}
+ config.comms.update_model(model.id, body)
diff --git a/cli/medperf/commands/cc/setup_cc_operator.py b/cli/medperf/commands/cc/setup_cc_operator.py
index a69a29aae..c36dd7937 100644
--- a/cli/medperf/commands/cc/setup_cc_operator.py
+++ b/cli/medperf/commands/cc/setup_cc_operator.py
@@ -17,9 +17,13 @@ def run_from_files(cls, cc_config_file: str):
@classmethod
def run(cls, cc_config: dict):
validate_cc_operator_config(cc_config)
+ user = get_medperf_user_object()
+ user.set_cc_config(cc_config)
+ body = {"metadata": user.metadata}
+ config.comms.update_user(user.id, body)
+
with config.ui.interactive():
- user = get_medperf_user_object()
- user.set_cc_config(cc_config)
setup_operator(user)
+ user.set_cc_initialized()
body = {"metadata": user.metadata}
config.comms.update_user(user.id, body)
diff --git a/cli/medperf/commands/certificate/server_certificate.py b/cli/medperf/commands/certificate/server_certificate.py
index 418020c7f..83165e9ca 100644
--- a/cli/medperf/commands/certificate/server_certificate.py
+++ b/cli/medperf/commands/certificate/server_certificate.py
@@ -22,4 +22,6 @@ def run(aggregator_id: int, overwrite: bool = False):
"Cert and key already present. Rerun the command with --overwrite"
)
remove_path(output_path, sensitive=True)
- get_server_cert(ca, address, output_path)
+
+ with config.ui.interactive():
+ get_server_cert(ca, address, output_path)
diff --git a/cli/medperf/commands/dataset/associate_benchmark.py b/cli/medperf/commands/dataset/associate_benchmark.py
index ea2a41077..9f47574e4 100644
--- a/cli/medperf/commands/dataset/associate_benchmark.py
+++ b/cli/medperf/commands/dataset/associate_benchmark.py
@@ -39,10 +39,10 @@ def run(data_uid: int, benchmark_uid: int, approved=False, no_cache=False):
no_cache=no_cache,
)[0]
results = execution.read_results()
- ui.print("These are the results generated by the compatibility test. ")
- ui.print("This will be sent along the association request.")
- ui.print("They will not be part of the benchmark.")
- dict_pretty_print(results)
+ ui.print("These are the results generated by the compatibility test. ")
+ ui.print("This will be sent along the association request.")
+ ui.print("They will not be part of the benchmark.")
+ dict_pretty_print(results)
msg = "Please confirm that you would like to associate"
msg += f" the dataset {dset.name} with the benchmark {benchmark.name}."
diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py
index 4aac247c5..4004bfa1a 100644
--- a/cli/medperf/commands/dataset/train.py
+++ b/cli/medperf/commands/dataset/train.py
@@ -46,13 +46,13 @@ def run(
execution.confirm_restart_on_failure()
while True:
- execution.prepare()
- execution.validate()
- execution.check_existing_outputs()
- execution.prepare_plan()
- execution.prepare_pki_assets()
- execution.confirm_run()
with config.ui.interactive():
+ execution.prepare()
+ execution.validate()
+ execution.check_existing_outputs()
+ execution.prepare_plan()
+ execution.prepare_pki_assets()
+ execution.confirm_run()
execution.prepare_training_cube()
try:
execution.run_experiment()
diff --git a/cli/medperf/commands/execution/execution_flow.py b/cli/medperf/commands/execution/execution_flow.py
index 3a63a2a9a..ac1f4648b 100644
--- a/cli/medperf/commands/execution/execution_flow.py
+++ b/cli/medperf/commands/execution/execution_flow.py
@@ -1,3 +1,4 @@
+from medperf import config
from medperf.entities.cube import Cube
from medperf.entities.model import Model
from medperf.entities.dataset import Dataset
@@ -67,7 +68,8 @@ def run(
)
else:
container = model.container_obj
- container.download_run_files()
+ with config.ui.interactive():
+ container.download_run_files()
return ContainerExecution.run(
dataset, container, evaluator, execution, ignore_model_errors
)
diff --git a/cli/medperf/commands/model/associate.py b/cli/medperf/commands/model/associate.py
index e282229c7..e580f7737 100644
--- a/cli/medperf/commands/model/associate.py
+++ b/cli/medperf/commands/model/associate.py
@@ -33,10 +33,10 @@ def run(
_, results = CompatibilityTestExecution.run(
benchmark=benchmark_uid, model=model_uid, no_cache=no_cache
)
- ui.print("These are the results generated by the compatibility test. ")
- ui.print("This will be sent along the association request.")
- ui.print("They will not be part of the benchmark.")
- dict_pretty_print(results)
+ ui.print("These are the results generated by the compatibility test. ")
+ ui.print("This will be sent along the association request.")
+ ui.print("They will not be part of the benchmark.")
+ dict_pretty_print(results)
msg = "Please confirm that you would like to associate "
msg += f"the model '{model.name}' with the benchmark '{benchmark.name}' [Y/n]"
diff --git a/cli/medperf/commands/training/get_experiment_status.py b/cli/medperf/commands/training/get_experiment_status.py
index 2df1bf445..0c9077f1b 100644
--- a/cli/medperf/commands/training/get_experiment_status.py
+++ b/cli/medperf/commands/training/get_experiment_status.py
@@ -23,10 +23,10 @@ def run(cls, training_exp_id: int, silent: bool = False):
training_exp_id (int): Training experiment UID.
"""
execution = cls(training_exp_id)
- execution.prepare()
- execution.prepare_plan()
- execution.prepare_pki_assets()
with config.ui.interactive():
+ execution.prepare()
+ execution.prepare_plan()
+ execution.prepare_pki_assets()
execution.prepare_admin_cube()
execution.get_experiment_status()
if not silent:
diff --git a/cli/medperf/commands/training/set_aggregator.py b/cli/medperf/commands/training/set_aggregator.py
new file mode 100644
index 000000000..dc01cc8c0
--- /dev/null
+++ b/cli/medperf/commands/training/set_aggregator.py
@@ -0,0 +1,58 @@
+from medperf import config
+from medperf.entities.aggregator import Aggregator
+from medperf.entities.training_exp import TrainingExp
+from medperf.utils import approval_prompt
+from medperf.exceptions import CleanExit, InvalidArgumentError
+
+
+class SetAggregator:
+ @classmethod
+ def run(cls, training_exp_id: int, aggregator_id: int, approved: bool = False):
+ """Sets the aggregator for a training experiment.
+
+ Args:
+ training_exp_id (int): UID of the training experiment
+ aggregator_id (int): UID of the registered aggregator to set
+ approved (bool): If True, skip confirmation prompt
+ """
+ ui = config.ui
+
+ submission = cls(training_exp_id, aggregator_id, approved)
+ with ui.interactive():
+ submission.validate()
+ submission.prepare()
+ submission.submit()
+
+ def __init__(self, training_exp_id: int, aggregator_id: int, approval: bool):
+ self.ui = config.ui
+ self.comms = config.comms
+ self.training_exp_id = training_exp_id
+ self.aggregator_id = aggregator_id
+ self.approved = approval
+
+ def validate(self):
+ if self.aggregator_id is None:
+ raise InvalidArgumentError("An aggregator ID must be provided")
+ if self.training_exp_id is None:
+ raise InvalidArgumentError("A training experiment ID must be provided")
+
+ def prepare(self):
+ self.training_exp = TrainingExp.get(self.training_exp_id)
+ self.aggregator = Aggregator.get(self.aggregator_id)
+
+ def submit(self):
+ training_exp_name = self.training_exp.name
+ self.ui.text = (
+ f"Setting aggregator for training experiment '{training_exp_name}'"
+ )
+ body = {"aggregator": self.aggregator_id}
+ msg = (
+ f"You are about to set the aggregator {self.aggregator.name} for training experiment {training_exp_name}."
+ " Do you confirm? [Y/n] "
+ )
+ self.approved = self.approved or approval_prompt(msg)
+
+ if self.approved:
+ self.comms.update_training_exp(self.training_exp_id, body)
+ return
+ raise CleanExit("Aggregator setting cancelled")
diff --git a/cli/medperf/commands/training/set_plan.py b/cli/medperf/commands/training/set_plan.py
index aed8aa884..787ae061f 100644
--- a/cli/medperf/commands/training/set_plan.py
+++ b/cli/medperf/commands/training/set_plan.py
@@ -3,7 +3,12 @@
from medperf.entities.training_exp import TrainingExp
from medperf.entities.cube import Cube
from medperf.exceptions import CleanExit, InvalidArgumentError
-from medperf.utils import approval_prompt, dict_pretty_print, generate_tmp_path
+from medperf.utils import (
+ approval_prompt,
+ dict_pretty_print,
+ generate_tmp_path,
+ sanitize_path,
+)
import os
import yaml
@@ -22,12 +27,14 @@ def run(
"""
planset = cls(training_exp_id, training_config_path, approval)
planset.validate()
- planset.prepare()
- planset.create_plan()
+ with config.ui.interactive():
+ planset.prepare()
+ planset.create_plan()
planset.update()
planset.write()
def __init__(self, training_exp_id: int, training_config_path: str, approval: bool):
+ training_config_path = sanitize_path(training_config_path)
self.ui = config.ui
self.training_exp_id = training_exp_id
self.training_config_path = os.path.abspath(training_config_path)
diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py
index 18f6400aa..61b72268b 100644
--- a/cli/medperf/commands/training/start_event.py
+++ b/cli/medperf/commands/training/start_event.py
@@ -1,6 +1,11 @@
from medperf.entities.training_exp import TrainingExp
from medperf.entities.event import TrainingEvent
-from medperf.utils import approval_prompt, dict_pretty_print, get_participant_label
+from medperf.utils import (
+ approval_prompt,
+ dict_pretty_print,
+ get_participant_label,
+ sanitize_path,
+)
from medperf.exceptions import CleanExit, InvalidArgumentError
import yaml
import os
@@ -27,7 +32,7 @@ def __init__(
):
self.training_exp_id = training_exp_id
self.name = name
- self.participants_list_file = participants_list_file
+ self.participants_list_file = sanitize_path(participants_list_file)
self.approved = approval
def prepare(self):
diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py
index dae474160..394cf0242 100644
--- a/cli/medperf/commands/training/submit.py
+++ b/cli/medperf/commands/training/submit.py
@@ -1,4 +1,5 @@
import medperf.config as config
+from medperf.entities.aggregator import Aggregator
from medperf.entities.training_exp import TrainingExp
from medperf.entities.cube import Cube
from medperf.utils import remove_path
@@ -29,10 +30,13 @@ def run(cls, training_exp_info: dict):
ui.text = "Getting FL admin Container"
submission.get_fl_admin_mlcube()
ui.print("> Completed retrieving FL Container")
+ ui.text = "Checking if an aggregator is provided"
+ submission.get_aggregator()
ui.text = "Submitting TrainingExp to MedPerf"
updated_benchmark_body = submission.submit()
ui.print("Uploaded")
submission.write(updated_benchmark_body)
+ return submission.training_exp.id
def __init__(self, training_exp_info: dict):
self.ui = config.ui
@@ -48,11 +52,16 @@ def get_fl_admin_mlcube(self):
if mlcube_id:
Cube.get(mlcube_id)
+ def get_aggregator(self):
+ aggregator_id = self.training_exp.aggregator
+ if aggregator_id:
+ Aggregator.get(aggregator_id)
+
def submit(self):
updated_body = self.training_exp.upload()
return updated_body
def write(self, updated_body):
remove_path(self.training_exp.path)
- training_exp = TrainingExp(**updated_body)
- training_exp.write()
+ self.training_exp = TrainingExp(**updated_body)
+ self.training_exp.write()
diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py
index 38db45f80..5da34365f 100644
--- a/cli/medperf/commands/training/training.py
+++ b/cli/medperf/commands/training/training.py
@@ -13,6 +13,7 @@
from medperf.commands.view import EntityView
from medperf.commands.training.get_experiment_status import GetExperimentStatus
from medperf.commands.training.update_plan import UpdatePlan
+from medperf.commands.training.set_aggregator import SetAggregator
app = typer.Typer()
@@ -37,6 +38,9 @@ def submit(
"--operational",
help="Submit the experiment as OPERATIONAL",
),
+ aggregator: int = typer.Option(
+ None, "--aggregator", "-g", help="UID of the registered aggregator to set"
+ ),
):
"""Submits a new benchmark to the platform"""
training_exp_info = {
@@ -45,6 +49,7 @@ def submit(
"docs_url": docs_url,
"fl_mlcube": fl_mlcube,
"fl_admin_mlcube": fl_admin_mlcube,
+ "aggregator": aggregator,
"demo_dataset_tarball_url": "link",
"demo_dataset_tarball_hash": "hash",
"demo_dataset_generated_uid": "uid",
@@ -119,6 +124,22 @@ def update_plan(
config.ui.print("✅ Done!")
+@app.command("set_aggregator")
+@clean_except
+def set_aggregator(
+ training_exp_id: int = typer.Option(
+ ..., "--training_exp_id", "-t", help="UID of the training experiment"
+ ),
+ aggregator_id: int = typer.Option(
+ ..., "--aggregator_id", "-a", help="UID of the aggregator to set"
+ ),
+ approval: bool = typer.Option(False, "-y", help="Skip approval step"),
+):
+ """Set the aggregator for a training experiment."""
+ SetAggregator.run(training_exp_id, aggregator_id, approved=approval)
+ config.ui.print("✅ Done!")
+
+
@app.command("close_event")
@clean_except
def close_event(
diff --git a/cli/medperf/comms/auth/auth0.py b/cli/medperf/comms/auth/auth0.py
index e7cf74c9c..0135cf6a0 100644
--- a/cli/medperf/comms/auth/auth0.py
+++ b/cli/medperf/comms/auth/auth0.py
@@ -47,7 +47,7 @@ def login(self, email):
)
config.ui.print_code(user_code)
config.ui.print("\n\n")
- config.ui.print_warning(
+ config.ui.print_cli_warning(
"Keep this terminal open until you complete your login request. "
"The command will exit on its own once you complete the request. "
"If you wish to stop the login request anyway, press Ctrl+C."
diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py
index f3127e160..424e831db 100644
--- a/cli/medperf/comms/rest.py
+++ b/cli/medperf/comms/rest.py
@@ -434,6 +434,21 @@ def get_aggregators(self, filters=dict()) -> List[dict]:
error_msg = "Could not retrieve aggregators"
return self.__get_list(url, filters=filters, error_msg=error_msg)
+ def get_aggregator_training_experiments(
+ self, aggregator_id: int, filters=dict()
+ ) -> List[dict]:
+ """Retrieves training experiments that have the given aggregator set.
+
+ Args:
+ aggregator_id (int): Aggregator UID
+
+ Returns:
+ List[dict]: List of training experiment data
+ """
+ url = f"{self.server_url}/aggregators/{aggregator_id}/training_experiments/"
+ error_msg = "Could not retrieve training experiments for aggregator"
+ return self.__get_list(url, filters=filters, error_msg=error_msg)
+
def get_cas(self, filters=dict()) -> List[dict]:
"""Retrieves all cas
@@ -641,16 +656,6 @@ def get_user_training_datasets_associations(self, filters=dict()) -> List[dict]:
error_msg = "Could not retrieve user datasets training associations"
return self.__get_list(url, filters=filters, error_msg=error_msg)
- def get_user_training_aggregators_associations(self, filters=dict()) -> List[dict]:
- """Get all aggregator associations related to the current user
-
- Returns:
- List[dict]: List containing all associations information
- """
- url = f"{self.server_url}/me/aggregators/training_associations/"
- error_msg = "Could not retrieve user aggregators training associations"
- return self.__get_list(url, filters=filters, error_msg=error_msg)
-
# upload
def upload_benchmark(self, benchmark_dict: dict) -> int:
"""Uploads a new benchmark to the server.
@@ -875,22 +880,6 @@ def associate_training_dataset(self, data_uid: int, training_exp_id: int):
error_msg = "Could not associate dataset to training_exp"
return self.__post(url, json=data, error_msg=error_msg)
- def associate_training_aggregator(self, aggregator_id: int, training_exp_id: int):
- """Create a aggregator experiment association
-
- Args:
- aggregator_id (int): Registered aggregator UID
- training_exp_id (int): training experiment UID
- """
- url = f"{self.server_url}/aggregators/training/"
- data = {
- "aggregator": aggregator_id,
- "training_exp": training_exp_id,
- "approval_status": Status.PENDING.value,
- }
- error_msg = "Could not associate aggregator to training_exp"
- return self.__post(url, json=data, error_msg=error_msg)
-
# updates associations
def update_benchmark_dataset_association(
self, benchmark_uid: int, dataset_uid: int, data: str
@@ -920,25 +909,6 @@ def update_benchmark_model_association(
error_msg = f"Could not update association: model {model_uid}, benchmark {benchmark_uid}"
self.__put(url, json=data, error_msg=error_msg)
- def update_training_aggregator_association(
- self, training_exp_id: int, aggregator_id: int, data: dict
- ):
- """Approves a aggregator association
-
- Args:
- dataset_uid (int): Dataset UID
- benchmark_uid (int): Benchmark UID
- status (str): Approval status to set for the association
- """
- url = (
- f"{self.server_url}/aggregators/{aggregator_id}/training/{training_exp_id}/"
- )
- error_msg = (
- "Could not update association: aggregator"
- f" {aggregator_id}, training_exp {training_exp_id}"
- )
- self.__put(url, json=data, error_msg=error_msg)
-
def update_training_dataset_association(
self, training_exp_id: int, dataset_uid: int, data: dict
):
diff --git a/cli/medperf/config.py b/cli/medperf/config.py
index 4434c96e1..50e03379e 100644
--- a/cli/medperf/config.py
+++ b/cli/medperf/config.py
@@ -353,3 +353,6 @@
# Data Import/Export config
archive_config_filename = "config.yaml"
+
+# Running containers processes
+running_containers = {}
diff --git a/cli/medperf/containers/runners/docker_runner.py b/cli/medperf/containers/runners/docker_runner.py
index dc6ce8f20..61b9ff80e 100644
--- a/cli/medperf/containers/runners/docker_runner.py
+++ b/cli/medperf/containers/runners/docker_runner.py
@@ -91,6 +91,8 @@ def run(
run_args = self.parser.get_run_args(task)
check_allowed_run_args(run_args)
+ run_args["task"] = task
+
add_medperf_run_args(run_args)
add_medperf_environment_variables(run_args, medperf_env)
add_user_defined_run_args(run_args)
@@ -182,8 +184,9 @@ def _run_encrypted_archive(
def _invoke_run(self, image, run_args, timeout, output_logs):
run_args["image"] = image
+ task = run_args.pop("task")
# Run
command = craft_docker_run_command(run_args)
logging.debug("Running docker container")
- run_command(command, timeout, output_logs)
+ run_command(command, timeout, output_logs, task=task)
diff --git a/cli/medperf/containers/runners/singularity_runner.py b/cli/medperf/containers/runners/singularity_runner.py
index c46a8d2fd..cf9d6223e 100644
--- a/cli/medperf/containers/runners/singularity_runner.py
+++ b/cli/medperf/containers/runners/singularity_runner.py
@@ -124,6 +124,8 @@ def run(
run_args = self.parser.get_run_args(task)
check_allowed_run_args(run_args)
+ run_args["task"] = task
+
add_medperf_run_args(run_args)
add_medperf_environment_variables(run_args, medperf_env)
add_user_defined_run_args(run_args)
@@ -285,8 +287,9 @@ def _run_encrypted_docker_archive(
def _invoke_run(self, image, run_args, timeout, output_logs):
run_args["image"] = image
+ task = run_args.pop("task")
# Run
command = craft_singularity_run_command(run_args, self.executable)
logging.debug("Running singulairty container")
- run_command(command, timeout, output_logs)
+ run_command(command, timeout, output_logs, task=task)
diff --git a/cli/medperf/entities/aggregator.py b/cli/medperf/entities/aggregator.py
index 612792754..a3e2ec375 100644
--- a/cli/medperf/entities/aggregator.py
+++ b/cli/medperf/entities/aggregator.py
@@ -1,4 +1,7 @@
import os
+from typing import List
+
+from medperf.entities.training_exp import TrainingExp
from medperf.entities.interface import Entity
from medperf.entities.schemas import AggregatorSchema
@@ -70,6 +73,18 @@ def from_experiment(cls, training_exp_uid: int) -> "Aggregator":
agg.write()
return agg
+ def get_training_experiments(self) -> List["TrainingExp"]:
+ """Training experiments that have this aggregator set (reverse relation).
+
+ Returns:
+ List[TrainingExp]: Training experiments with this aggregator set.
+ """
+
+ if self.id is None:
+ return []
+ data_list = config.comms.get_aggregator_training_experiments(self.id)
+ return [TrainingExp(**d) for d in data_list]
+
@classmethod
def remote_prefilter(cls, filters: dict) -> callable:
"""Applies filtering logic that must be done before retrieving remote entities
diff --git a/cli/medperf/entities/asset.py b/cli/medperf/entities/asset.py
index b9d4ba1d8..fa8d5a2a0 100644
--- a/cli/medperf/entities/asset.py
+++ b/cli/medperf/entities/asset.py
@@ -64,6 +64,13 @@ def is_local(self) -> bool:
def is_model(self) -> bool:
return True
+ def check_hash(self) -> bool:
+ try:
+ self.get_archive_path()
+ except InvalidEntityError:
+ return False
+ return True
+
@staticmethod
def remote_prefilter(filters: dict) -> callable:
comms_fn = config.comms.get_assets
diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py
index b66183192..54cd69595 100644
--- a/cli/medperf/entities/cube.py
+++ b/cli/medperf/entities/cube.py
@@ -104,6 +104,9 @@ def is_model(self) -> bool:
def is_script(self) -> bool:
return self.parser.is_script_container()
+ def check_hash(self) -> bool:
+ raise NotImplementedError("Hash checking not implemented for container.")
+
@staticmethod
def remote_prefilter(filters: dict):
"""Applies filtering logic that must be done before retrieving remote entities
diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py
index b7d61bbf9..15500b8f7 100644
--- a/cli/medperf/entities/dataset.py
+++ b/cli/medperf/entities/dataset.py
@@ -11,6 +11,7 @@
from medperf.entities.utils import handle_validation_error
from medperf.exceptions import InvalidEntityError
import logging
+from datetime import datetime, timezone
class Dataset(Entity):
@@ -80,6 +81,12 @@ def set_cc_config(self, cc_config: dict):
if "cc" not in self.user_metadata:
self.user_metadata["cc"] = {}
self.user_metadata["cc"]["config"] = cc_config
+ self.user_metadata["cc"]["initialized"] = False
+
+ def set_cc_initialized(self):
+ if not self.is_cc_configured():
+ return
+ self.user_metadata["cc"]["initialized"] = True
def get_cc_policy(self):
cc_values = self.user_metadata.get("cc", {})
@@ -93,6 +100,20 @@ def set_cc_policy(self, cc_policy: dict):
def is_cc_configured(self):
return self.get_cc_config() != {}
+ def is_cc_initialized(self):
+ cc_values = self.user_metadata.get("cc", {})
+ return cc_values.get("initialized", False)
+
+ def set_last_synced(self):
+ if "cc" not in self.user_metadata:
+ self.user_metadata["cc"] = {}
+ self.user_metadata["cc"]["last_synced"] = str(datetime.now(timezone.utc))
+
+ def get_last_synced(self):
+ if "cc" not in self.user_metadata:
+ return
+ return self.user_metadata["cc"].get("last_synced", None)
+
def is_operational(self):
return self.state == "OPERATION"
diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py
index 9022ebd67..832a15b2d 100644
--- a/cli/medperf/entities/event.py
+++ b/cli/medperf/entities/event.py
@@ -63,7 +63,9 @@ def _set_helper_attributes(self):
self.agg_out_logs = os.path.join(
self.path, config.training_out_agg_logs + timestamp
)
- self.col_out_logs = os.path.join(self.path, config.training_out_col_logs)
+ self.col_out_logs = os.path.join(
+ self.path, config.training_out_col_logs + timestamp
+ )
self.out_weights = os.path.join(
self.path, config.training_out_weights + timestamp
)
diff --git a/cli/medperf/entities/model.py b/cli/medperf/entities/model.py
index f92f0a4e5..2934385ac 100644
--- a/cli/medperf/entities/model.py
+++ b/cli/medperf/entities/model.py
@@ -9,6 +9,7 @@
from medperf.account_management import get_medperf_user_data
from medperf.commands.association.utils import get_user_associations
from medperf.entities.utils import handle_validation_error
+from datetime import datetime, timezone
class Model(Entity):
@@ -88,6 +89,12 @@ def set_cc_config(self, cc_config: dict):
if "cc" not in self.user_metadata:
self.user_metadata["cc"] = {}
self.user_metadata["cc"]["config"] = cc_config
+ self.user_metadata["cc"]["initialized"] = False
+
+ def set_cc_initialized(self):
+ if not self.is_cc_configured():
+ return
+ self.user_metadata["cc"]["initialized"] = True
def get_cc_policy(self):
cc_values = self.user_metadata.get("cc", {})
@@ -101,6 +108,30 @@ def set_cc_policy(self, cc_policy: dict):
def is_cc_configured(self):
return self.get_cc_config() != {}
+ def is_cc_initialized(self):
+ cc_values = self.user_metadata.get("cc", {})
+ return cc_values.get("initialized", False)
+
+ def set_last_synced(self):
+ if "cc" not in self.user_metadata:
+ self.user_metadata["cc"] = {}
+ self.user_metadata["cc"]["last_synced"] = str(datetime.now(timezone.utc))
+
+ def get_last_synced(self):
+ if "cc" not in self.user_metadata:
+ return
+ return self.user_metadata["cc"].get("last_synced", None)
+
+ def check_hash(self) -> bool:
+ if self.is_container():
+ return self.container_obj.check_hash()
+ elif self.is_asset():
+ return self.asset_obj.check_hash()
+ else:
+ raise MedperfException(
+ "Internal error: Model is neither a container nor an asset"
+ )
+
@staticmethod
def remote_prefilter(filters: dict) -> callable:
comms_fn = config.comms.get_models
diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py
index 09460f0d4..cd58d7e41 100644
--- a/cli/medperf/entities/schemas.py
+++ b/cli/medperf/entities/schemas.py
@@ -183,6 +183,7 @@ class TrainingExpSchema(MedperfSchema):
plan: dict = {}
metadata: dict = {}
user_metadata: dict = {}
+ aggregator: Optional[int]
class UserSchema(BaseModel):
diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py
index 1167fb855..4047afca6 100644
--- a/cli/medperf/entities/training_exp.py
+++ b/cli/medperf/entities/training_exp.py
@@ -64,6 +64,7 @@ def __init__(self, **kwargs):
self.plan = self._model.plan
self.metadata = self._model.metadata
self.user_metadata = self._model.user_metadata
+ self.aggregator = self._model.aggregator
self._set_helper_attributes()
@@ -138,6 +139,7 @@ def display_dict(self):
"Documentation": self.docs_url,
"Created At": self.created_at,
"FL Container": int(self.fl_mlcube),
+ "Aggregator": int(self.aggregator) if self.aggregator else None,
"Plan": self.plan,
"State": self.state,
"Registered": self.is_registered,
diff --git a/cli/medperf/entities/user.py b/cli/medperf/entities/user.py
index d7b836907..ea74e60cf 100644
--- a/cli/medperf/entities/user.py
+++ b/cli/medperf/entities/user.py
@@ -36,6 +36,16 @@ def set_cc_config(self, cc_config: dict):
if "cc" not in self.metadata:
self.metadata["cc"] = {}
self.metadata["cc"]["config"] = cc_config
+ self.metadata["cc"]["initialized"] = False
+
+ def set_cc_initialized(self):
+ if not self.is_cc_configured():
+ return
+ self.metadata["cc"]["initialized"] = True
def is_cc_configured(self):
return self.get_cc_config() != {}
+
+ def is_cc_initialized(self):
+ cc_values = self.metadata.get("cc", {})
+ return cc_values.get("initialized", False)
diff --git a/cli/medperf/init.py b/cli/medperf/init.py
index eb505e6ff..34c739a06 100644
--- a/cli/medperf/init.py
+++ b/cli/medperf/init.py
@@ -10,6 +10,7 @@
override_storage_config_paths,
)
from medperf.ui.factory import UIFactory
+from medperf.ui.web_ui import WebUI
def initialize(for_webui=False, for_data_monitor=False):
@@ -36,7 +37,14 @@ def initialize(for_webui=False, for_data_monitor=False):
setup_logging(log_file, config.loglevel)
# Setup UI, COMMS
- config.ui = UIFactory.create_ui(ui_class)
+ # Preserve existing WebUI instance when only switching profile (for_webui=True).
+ # Replacing it would break in-flight tasks: worker threads use config.ui and
+ # thread-local task_id; a new WebUI would have empty thread-locals and a new
+ # EventsManager, so aggregator/training logs would get task_id=None and be dropped.
+ if for_webui and isinstance(getattr(config, "ui", None), WebUI):
+ pass # keep existing config.ui (same thread-locals and events_manager)
+ else:
+ config.ui = UIFactory.create_ui(ui_class)
config.comms = CommsFactory.create_comms(config.comms, config.server)
# Setup auth class
diff --git a/cli/medperf/tests/commands/association/test_approve.py b/cli/medperf/tests/commands/association/test_approve.py
index b93834fac..a6f23d50f 100644
--- a/cli/medperf/tests/commands/association/test_approve.py
+++ b/cli/medperf/tests/commands/association/test_approve.py
@@ -21,10 +21,6 @@
{"training_exp_uid": 1, "dataset_uid": 1},
"update_training_dataset_association",
),
- (
- {"training_exp_uid": 1, "aggregator_uid": 1},
- "update_training_aggregator_association",
- ),
],
)
def test_run_calls_correct_comms_method(mocker, comms, ui, kwargs, comms_method):
diff --git a/cli/medperf/tests/commands/association/test_list.py b/cli/medperf/tests/commands/association/test_list.py
index 011bc5da4..77415d982 100644
--- a/cli/medperf/tests/commands/association/test_list.py
+++ b/cli/medperf/tests/commands/association/test_list.py
@@ -21,10 +21,6 @@
{"training_exp": True, "dataset": True},
["training_exp", "dataset"],
),
- (
- {"training_exp": True, "aggregator": True},
- ["training_exp", "aggregator"],
- ),
],
)
def test_run_calls_correct_comms_method(mocker, comms, ui, kwargs, expected_util_args):
diff --git a/cli/medperf/tests/commands/association/test_utils.py b/cli/medperf/tests/commands/association/test_utils.py
index be830d6d2..a42bacb76 100644
--- a/cli/medperf/tests/commands/association/test_utils.py
+++ b/cli/medperf/tests/commands/association/test_utils.py
@@ -6,7 +6,6 @@
@pytest.mark.parametrize("dset_uid", [None, 1])
@pytest.mark.parametrize("mlcube_uid", [None, 1])
-@pytest.mark.parametrize("aggregartor_uid", [None, 1])
@pytest.mark.parametrize("bmk_uid", [None, 1])
@pytest.mark.parametrize("training_exp_uid", [None, 1])
def test_validate_args_fails_if_invalid_arguments(
@@ -15,29 +14,22 @@ def test_validate_args_fails_if_invalid_arguments(
ui,
dset_uid,
mlcube_uid,
- aggregartor_uid,
bmk_uid,
training_exp_uid,
):
# Arrange
- number_of_components_provided = (
- int(dset_uid is not None)
- + int(mlcube_uid is not None)
- + int(aggregartor_uid is not None)
+ number_of_components_provided = int(dset_uid is not None) + int(
+ mlcube_uid is not None
)
number_of_experiments_provided = int(bmk_uid is not None) + int(
training_exp_uid is not None
)
- is_training_component = dset_uid or aggregartor_uid
is_evaluation_component = dset_uid or mlcube_uid
should_succeed = (
number_of_components_provided == 1
and number_of_experiments_provided == 1
- and (
- (training_exp_uid and is_training_component)
- or (bmk_uid and is_evaluation_component)
- )
+ and ((is_evaluation_component and bmk_uid) or (training_exp_uid and dset_uid))
)
# Act & Assert
@@ -48,7 +40,6 @@ def test_validate_args_fails_if_invalid_arguments(
training_exp_uid,
dset_uid,
mlcube_uid,
- aggregartor_uid,
Status.APPROVED.value,
)
else:
@@ -57,7 +48,6 @@ def test_validate_args_fails_if_invalid_arguments(
training_exp_uid,
dset_uid,
mlcube_uid,
- aggregartor_uid,
Status.APPROVED.value,
)
@@ -77,7 +67,6 @@ def test_validate_args_fails_if_invalid_approval_status(
None,
1,
None,
- None,
approval_status,
)
else:
@@ -86,7 +75,6 @@ def test_validate_args_fails_if_invalid_approval_status(
None,
1,
None,
- None,
approval_status,
)
@@ -145,42 +133,3 @@ def test_3_latest_associations():
assert sorted(result, key=lambda x: (x["component"], x["experiment"])) == sorted(
expected, key=lambda x: (x["component"], x["experiment"])
)
-
-
-def test_1_last_component():
- # Arrange
- associations = [
- {"created_at": "2025-04-16 17:38:33", "component": 1, "experiment": 2},
- {"created_at": "2025-04-16 17:34:33", "component": 1, "experiment": 2},
- {"created_at": "2025-04-16 17:32:33", "component": 2, "experiment": 2},
- ]
- expected = [associations[0]]
-
- # Act
- result = utils.get_last_component(associations, "experiment")
-
- # Assert
- # sort them in some way
- assert sorted(result, key=lambda x: (x["component"], x["experiment"])) == sorted(
- expected, key=lambda x: (x["component"], x["experiment"])
- )
-
-
-def test_2_last_component():
- # Arrange
- associations = [
- {"created_at": "2025-04-16 17:38:33", "component": 1, "experiment": 2},
- {"created_at": "2025-04-16 17:34:33", "component": 1, "experiment": 2},
- {"created_at": "2025-04-17 17:32:33", "component": 2, "experiment": 2},
- {"created_at": "2025-04-16 17:32:33", "component": 2, "experiment": 3},
- ]
- expected = [associations[2], associations[3]]
-
- # Act
- result = utils.get_last_component(associations, "experiment")
-
- # Assert
- # sort them in some way
- assert sorted(result, key=lambda x: (x["component"], x["experiment"])) == sorted(
- expected, key=lambda x: (x["component"], x["experiment"])
- )
diff --git a/cli/medperf/ui/cli.py b/cli/medperf/ui/cli.py
index 5367cb49f..ba035e5fe 100644
--- a/cli/medperf/ui/cli.py
+++ b/cli/medperf/ui/cli.py
@@ -75,6 +75,9 @@ def print_error(self, msg: str):
def print_critical(self, msg: str):
self.print_warning(msg)
+ def print_cli_warning(self, msg: str):
+ self.print_warning(msg)
+
def print_warning(self, msg: str):
"""Display a warning message on the command line
diff --git a/cli/medperf/ui/interface.py b/cli/medperf/ui/interface.py
index 42c676f77..6c9fc5745 100644
--- a/cli/medperf/ui/interface.py
+++ b/cli/medperf/ui/interface.py
@@ -13,6 +13,9 @@ def print(self, msg: str = ""):
def print_error(self, msg: str):
"""Display an error message to the interface"""
+ def print_cli_warning(self, msg: str):
+ """Display a warning message only for the CLI interface."""
+
def print_warning(self, msg: str):
"""Display a warning message on the command line"""
diff --git a/cli/medperf/ui/web_ui.py b/cli/medperf/ui/web_ui.py
index 149120833..eaa607ffc 100644
--- a/cli/medperf/ui/web_ui.py
+++ b/cli/medperf/ui/web_ui.py
@@ -1,5 +1,6 @@
from queue import Queue
from contextlib import contextmanager
+import threading
from yaspin import yaspin
import typer
@@ -13,10 +14,20 @@ def __init__(self):
self.responses: Queue[dict] = Queue()
self.is_interactive = False
self.spinner = yaspin(color="green")
- self.task_id = None
+ self._task_id_local = threading.local()
+ self._primary_task_id = None
self.events_manager = EventsManager()
self.global_events_manager = GlobalEventsManager()
+ def _current_task_id(self):
+ """Per-thread task_id so concurrent tasks (e.g. aggregator + training) tag events correctly."""
+ return getattr(self._task_id_local, "task_id", None)
+
+ @property
+ def task_id(self):
+ """Primary task_id for GET /current_task (last task started by a request). Events use _current_task_id()."""
+ return self._primary_task_id
+
def print_error(self, msg: str):
"""Display an error message on the command line
@@ -27,6 +38,10 @@ def print_error(self, msg: str):
msg = typer.style(msg, fg=typer.colors.RED, bold=True)
self._print(msg, "error")
+ def print_cli_warning(self, msg: str):
+ # do nothing
+ pass
+
def print_warning(self, msg: str):
"""Display a warning message on the command line
@@ -44,7 +59,7 @@ def _print(self, msg: str = "", type: str = "print"):
self.set_event(
Event(
- task_id=self.task_id,
+ task_id=self._current_task_id(),
type=type,
message=msg,
interactive=self.is_interactive,
@@ -99,7 +114,7 @@ def text(self, msg: str = ""):
self.set_event(
Event(
- task_id=self.task_id,
+ task_id=self._current_task_id(),
type="text",
message=msg,
interactive=self.is_interactive,
@@ -120,7 +135,7 @@ def prompt(self, msg: str) -> str:
msg = msg.replace(" [Y/n]", "")
self.set_event(
Event(
- task_id=self.task_id,
+ task_id=self._current_task_id(),
type="prompt",
message=msg,
interactive=self.is_interactive,
@@ -182,8 +197,9 @@ def print_critical(self, msg: str):
def set_event(self, event: Event):
self.events_manager.process_event(event)
- def get_event(self, timeout=None):
- return self.events_manager.dequeue_event(timeout=timeout)
+ def get_event(self, task_id=None, timeout=None):
+ """Get next event for the given task_id (required for SSE so each stream gets only its task's events)."""
+ return self.events_manager.dequeue_event(task_id=task_id, timeout=timeout)
def set_response(self, event):
self.responses.put(event)
@@ -196,7 +212,7 @@ def end_task(self, response=None):
self.events_manager.enqueue_event(
Event(
- task_id=self.task_id,
+ task_id=self._current_task_id(),
type="highlight",
message="",
interactive=self.is_interactive,
@@ -207,14 +223,16 @@ def end_task(self, response=None):
self.unset_task_id()
def start_task(self, task_id: str):
+ self._primary_task_id = task_id
self.set_task_id(task_id)
self.events_manager.start_buffering()
def set_task_id(self, task_id):
- self.task_id = task_id
+ self._task_id_local.task_id = task_id
def unset_task_id(self):
- self.task_id = None
+ if hasattr(self._task_id_local, "task_id"):
+ delattr(self._task_id_local, "task_id")
def add_notification(self, message, return_response, url=""):
self.global_events_manager.add_notification(message, return_response, url)
diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py
index 4c96b5910..735fc40b6 100644
--- a/cli/medperf/utils.py
+++ b/cli/medperf/utils.py
@@ -23,6 +23,7 @@
import medperf.config as config
from medperf.exceptions import CleanExit, ExecutionError, InvalidArgumentError
import shlex
+import time
from email_validator import validate_email, EmailNotValidError
@@ -505,9 +506,10 @@ def check_for_updates() -> None:
class spawn_and_kill:
- def __init__(self, cmd, timeout=None, *args, **kwargs):
+ def __init__(self, cmd, timeout=None, task=None, *args, **kwargs):
self.cmd = cmd
self.timeout = timeout
+ self.task = task
self._args = args
self._kwargs = kwargs
self.proc: spawn
@@ -518,16 +520,37 @@ def spawn(*args, **kwargs):
return spawn(*args, **kwargs)
def killpg(self):
- os.killpg(self.pid, signal.SIGINT)
+ """Stop the process group: SIGTERM first, then SIGKILL after a short wait."""
+ try:
+ pgid = os.getpgid(self.pid)
+ except OSError:
+ pgid = self.pid
+ try:
+ os.killpg(pgid, signal.SIGTERM)
+ except OSError:
+ pass
+ time.sleep(1)
+ try:
+ if self.proc.isalive():
+ os.killpg(pgid, signal.SIGKILL)
+ except OSError:
+ pass
def __enter__(self):
self.proc = self.spawn(
self.cmd, timeout=self.timeout, *self._args, **self._kwargs
)
self.pid = self.proc.pid
+
+ if self.task:
+ config.running_containers[self.task] = self
+
return self
def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.task:
+ config.running_containers.pop(self.task, None)
+
if exc_type:
self.exception_occurred = True
# Forcefully kill the process group if any exception occurred, in particular,
@@ -543,11 +566,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False
-def run_command(cmd, timeout=None, output_logs=None):
+def run_command(cmd, timeout=None, output_logs=None, task=None):
logging.debug(f"Command as list, to be run: {cmd}")
command_as_str = shlex.join(cmd)
logging.debug(f"Running command: {command_as_str}")
- with spawn_and_kill(command_as_str, timeout=timeout) as proc_wrapper:
+ with spawn_and_kill(command_as_str, timeout=timeout, task=task) as proc_wrapper:
proc = proc_wrapper.proc
proc_out = combine_proc_sp_text(proc)
diff --git a/cli/medperf/web_ui/aggregators/routes.py b/cli/medperf/web_ui/aggregators/routes.py
new file mode 100644
index 000000000..09b78d9d9
--- /dev/null
+++ b/cli/medperf/web_ui/aggregators/routes.py
@@ -0,0 +1,204 @@
+import logging
+import threading
+
+from fastapi import APIRouter, Depends, Form, Request
+from fastapi.responses import HTMLResponse, JSONResponse
+
+import medperf.config as config
+from medperf.account_management import get_medperf_user_data
+from medperf.commands.aggregator.submit import SubmitAggregator
+from medperf.commands.certificate.server_certificate import GetServerCertificate
+from medperf.commands.aggregator.run import StartAggregator
+from medperf.entities.aggregator import Aggregator
+from medperf.entities.cube import Cube
+from medperf.web_ui.common import (
+ check_user_api,
+ check_user_ui,
+ initialize_state_task,
+ reset_state_task,
+ templates,
+)
+
+router = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+@router.get("/register/ui", response_class=HTMLResponse)
+def register_aggregator_ui(
+ request: Request,
+ current_user: bool = Depends(check_user_ui),
+):
+ containers = Cube.all()
+ containers = [{"id": c.id, "name": c.name} for c in containers]
+ return templates.TemplateResponse(
+ "aggregators/register_aggregator.html",
+ {"request": request, "containers": containers},
+ )
+
+
+@router.post("/register", response_class=JSONResponse)
+def register_aggregator(
+ request: Request,
+ name: str = Form(...),
+ address: str = Form(...),
+ port: int = Form(...),
+ aggregation_mlcube: int = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ initialize_state_task(request, task_name="register_aggregator")
+ return_response = {"status": "", "error": "", "entity_id": None}
+ aggregator_id = None
+ try:
+ aggregator_id = SubmitAggregator.run(
+ name=name,
+ address=address.strip(),
+ port=port,
+ aggregation_mlcube=aggregation_mlcube,
+ )
+ return_response["status"] = "success"
+ return_response["entity_id"] = aggregator_id
+ notification_message = "Aggregator successfully registered"
+ except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "Failed to register aggregator"
+ logger.exception(exp)
+
+ config.ui.end_task(return_response)
+ reset_state_task(request)
+ redirect_url = (
+ f"/aggregators/ui/display/{aggregator_id}"
+ if aggregator_id
+ else "/aggregators/register/ui"
+ )
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=redirect_url,
+ )
+ return return_response
+
+
+@router.get("/ui", response_class=HTMLResponse)
+def aggregators_ui(
+ request: Request,
+ mine_only: bool = False,
+ current_user: bool = Depends(check_user_ui),
+):
+ filters = {}
+ my_user_id = get_medperf_user_data()["id"]
+ if mine_only:
+ filters["owner"] = my_user_id
+
+ aggregators = Aggregator.all(filters=filters)
+ aggregators = sorted(aggregators, key=lambda x: x.created_at or "", reverse=True)
+ mine_aggs = [a for a in aggregators if a.owner == my_user_id]
+ other_aggs = [a for a in aggregators if a.owner != my_user_id]
+ aggregators = mine_aggs + other_aggs
+
+ return templates.TemplateResponse(
+ "aggregators/aggregators.html",
+ {"request": request, "aggregators": aggregators, "mine_only": mine_only},
+ )
+
+
+@router.get("/ui/display/{aggregator_id}", response_class=HTMLResponse)
+def aggregator_detail_ui(
+ request: Request,
+ aggregator_id: int,
+ current_user: bool = Depends(check_user_ui),
+):
+ my_user_id = get_medperf_user_data()["id"]
+ entity = Aggregator.get(aggregator_id)
+ owner = entity.owner == my_user_id
+ # Training experiments that have this aggregator set (reverse relation)
+ experiments_using_aggregator = entity.get_training_experiments()
+
+ return templates.TemplateResponse(
+ "aggregators/aggregator_detail.html",
+ {
+ "request": request,
+ "entity": entity,
+ "experiments_using_aggregator": experiments_using_aggregator,
+ "owner": owner,
+ },
+ )
+
+
+@router.post("/get_server_certificate", response_class=JSONResponse)
+def get_server_certificate(
+ request: Request,
+ aggregator_id: int = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ initialize_state_task(request, task_name="aggregator_get_server_cert")
+ return_response = {"status": "", "error": ""}
+ try:
+ GetServerCertificate.run(aggregator_id=aggregator_id)
+ return_response["status"] = "success"
+ notification_message = "Server certificate retrieved successfully"
+ except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "Failed to get server certificate"
+ logger.exception(exp)
+
+ config.ui.end_task(return_response)
+ reset_state_task(request)
+ redirect_url = f"/aggregators/ui/display/{aggregator_id}"
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=redirect_url,
+ )
+ return return_response
+
+
+def _run_aggregator_worker(
+ request: Request,
+ training_exp_id: int,
+ aggregator_id: int,
+ task_id: str,
+):
+ redirect_url = f"/aggregators/ui/display/{aggregator_id}"
+ return_response = {"status": "", "error": ""}
+ notification_message = "Aggregator run started successfully"
+ config.ui.set_task_id(task_id)
+ try:
+ StartAggregator.run(training_exp_id=training_exp_id, publish_on="0.0.0.0")
+ return_response["status"] = "success"
+ except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "An error occurred while running the aggregator"
+ logger.exception(exp)
+
+ config.ui.end_task(return_response)
+ reset_state_task(request, task_id)
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=redirect_url,
+ )
+
+
+@router.post("/run", response_class=JSONResponse)
+def run_aggregator(
+ request: Request,
+ aggregator_id: int = Form(...),
+ training_exp_id: int = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ agg_meta = config.comms.get_experiment_aggregator(training_exp_id)
+ if not agg_meta or agg_meta.get("id") != aggregator_id:
+ raise ValueError("Selected training experiment does not use this aggregator")
+
+ task_id = initialize_state_task(request, task_name="start_aggregator")
+
+ threading.Thread(
+ target=_run_aggregator_worker,
+ args=(request, training_exp_id, aggregator_id, task_id),
+ daemon=True,
+ ).start()
+
+ return {"status": "started", "error": ""}
diff --git a/cli/medperf/web_ui/api/routes.py b/cli/medperf/web_ui/api/routes.py
index 10f76c296..448ac89e7 100644
--- a/cli/medperf/web_ui/api/routes.py
+++ b/cli/medperf/web_ui/api/routes.py
@@ -4,6 +4,7 @@
from fastapi import APIRouter, HTTPException, Form, Depends
from fastapi.responses import JSONResponse
+import medperf.config as config
from medperf.exceptions import InvalidArgumentError
from medperf.web_ui.common import check_user_api
from medperf.utils import sanitize_path
@@ -11,6 +12,29 @@
router = APIRouter()
+@router.get("/running_tasks", response_class=JSONResponse)
+def get_running_tasks(current_user: bool = Depends(check_user_api)):
+ tasks = list(config.running_containers.keys())
+ return {"tasks": tasks}
+
+
+@router.post("/stop_task", response_class=JSONResponse)
+def stop_task(
+ task_name: str = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ wrapper = config.running_containers.get(task_name)
+ if wrapper is None:
+ raise HTTPException(
+ status_code=404, detail=f"No running task named '{task_name}'"
+ )
+ try:
+ wrapper.killpg()
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+ return {"status": "ok"}
+
+
# TODO: close with token and list in documentation
@router.post("/browse", response_class=JSONResponse)
def browse_directory(
diff --git a/cli/medperf/web_ui/app.py b/cli/medperf/web_ui/app.py
index 45a65b1d4..863214505 100644
--- a/cli/medperf/web_ui/app.py
+++ b/cli/medperf/web_ui/app.py
@@ -5,6 +5,7 @@
from medperf.logging.utils import log_machine_details
import typer
+from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import FastAPI, Request
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
@@ -18,6 +19,8 @@
from medperf.web_ui.containers.routes import router as containers_router
from medperf.web_ui.models.routes import router as models_router
from medperf.web_ui.assets.routes import router as assets_router
+from medperf.web_ui.training.routes import router as training_router
+from medperf.web_ui.aggregators.routes import router as aggregators_router
from medperf.web_ui.yaml_fetch.routes import router as yaml_fetch_router
from medperf.web_ui.api.routes import router as api_router
from medperf.web_ui.security_check import router as login_router
@@ -25,15 +28,37 @@
from medperf.web_ui.medperf_login import router as medperf_login
from medperf.web_ui.settings import router as settings_router
from medperf.web_ui.auth import wrap_openapi, NotAuthenticatedException, security_token
-from medperf.web_ui.schemas import WebUITask
+
+JS_VERSION = "1.0.0"
+
+UI_MODE_COOKIE = "medperf-mode"
+UI_MODE_TRAINING = "training"
+UI_MODE_EVALUATION = "evaluation"
+
+
+class NavModeMiddleware(BaseHTTPMiddleware):
+ """Set request.app.state.ui_mode from cookie so templates and routes can use it."""
+
+ async def dispatch(self, request, call_next):
+ request.app.state.ui_mode = request.cookies.get(
+ UI_MODE_COOKIE, UI_MODE_EVALUATION
+ )
+ if request.app.state.ui_mode not in (UI_MODE_EVALUATION, UI_MODE_TRAINING):
+ request.app.state.ui_mode = UI_MODE_EVALUATION
+ return await call_next(request)
+
web_app = FastAPI()
+web_app.add_middleware(NavModeMiddleware)
+
web_app.include_router(datasets_router, prefix="/datasets")
web_app.include_router(benchmarks_router, prefix="/benchmarks")
web_app.include_router(containers_router, prefix="/containers")
web_app.include_router(models_router, prefix="/models")
web_app.include_router(assets_router, prefix="/assets")
+web_app.include_router(training_router, prefix="/training")
+web_app.include_router(aggregators_router, prefix="/aggregators")
web_app.include_router(yaml_fetch_router)
web_app.include_router(api_router, prefix="/api")
web_app.include_router(login_router)
@@ -43,7 +68,9 @@
static_folder_path = Path(resources.files("medperf.web_ui")) / "static"
-web_app.mount("/static", StaticFiles(directory=static_folder_path), name="static")
+web_app.mount(
+ f"/static/v{JS_VERSION}", StaticFiles(directory=static_folder_path), name="static"
+)
web_app.add_exception_handler(Exception, custom_exception_handler)
@@ -52,7 +79,7 @@
@web_app.on_event("startup")
def startup_event():
- web_app.state.task = WebUITask()
+ web_app.state.active_tasks = {} # task_id -> WebUITask (multiple tasks can run)
web_app.state.old_tasks = [] # List of [schemas.WebUITask]
web_app.state.task_running = False
web_app.state.MAXLOGMESSAGES = config.webui_max_log_messages
@@ -70,6 +97,11 @@ def startup_event():
"interval": 0,
}
+    # Default the UI mode to evaluation at startup; NavModeMiddleware refreshes it on each request from the cookie.
+ web_app.state.ui_mode = UI_MODE_EVALUATION
+ web_app.state.TRAINING_MODE = UI_MODE_TRAINING
+ web_app.state.EVALUATION_MODE = UI_MODE_EVALUATION
+
# continue setup logging
host_props = {**web_app.state.host_props, "security_token": security_token}
with open(config.webui_host_props, "w") as f:
@@ -94,10 +126,24 @@ def not_authenticated_exception_handler(
@web_app.get("/", include_in_schema=False)
-def read_root():
+def read_root(request: Request):
+ if request.app.state.ui_mode == UI_MODE_TRAINING:
+ return RedirectResponse(url="/training/ui")
return RedirectResponse(url="/benchmarks/ui")
+@web_app.get("/set_mode", include_in_schema=False)
+def set_mode(request: Request, mode: str = "evaluation"):
+ """Set nav mode (evaluation | training) via cookie and redirect to the default page for that mode."""
+ if mode == UI_MODE_TRAINING:
+ response = RedirectResponse(url="/training/ui")
+ response.set_cookie(key=UI_MODE_COOKIE, value=UI_MODE_TRAINING, path="/")
+ else:
+ response = RedirectResponse(url="/benchmarks/ui")
+ response.set_cookie(key=UI_MODE_COOKIE, value=UI_MODE_EVALUATION, path="/")
+ return response
+
+
app = typer.Typer()
diff --git a/cli/medperf/web_ui/assets/routes.py b/cli/medperf/web_ui/assets/routes.py
index 99deb5c9f..1aa8e009f 100644
--- a/cli/medperf/web_ui/assets/routes.py
+++ b/cli/medperf/web_ui/assets/routes.py
@@ -58,7 +58,7 @@ def register_asset(
):
initialize_state_task(request, task_name="asset_registration")
- return_response = {"status": "", "error": "", "asset_id": None}
+ return_response = {"status": "", "error": "", "asset_id": None, "entity_id": None}
asset_id = None
try:
asset_id = SubmitAsset.run(
@@ -69,6 +69,7 @@ def register_asset(
)
return_response["status"] = "success"
return_response["asset_id"] = asset_id
+ return_response["entity_id"] = asset_id
notification_message = "Asset successfully registered"
except Exception as exp:
return_response["status"] = "failed"
diff --git a/cli/medperf/web_ui/benchmarks/routes.py b/cli/medperf/web_ui/benchmarks/routes.py
index 2f84e6431..02087a3a0 100644
--- a/cli/medperf/web_ui/benchmarks/routes.py
+++ b/cli/medperf/web_ui/benchmarks/routes.py
@@ -197,8 +197,8 @@ def register_benchmark(
"data_evaluator_mlcube": evaluator_container,
"state": "OPERATION",
}
- initialize_state_task(request, task_name="benchmark_registration")
- return_response = {"status": "", "error": "", "benchmark_id": None}
+ initialize_state_task(request, task_name="register_benchmark")
+ return_response = {"status": "", "error": "", "entity_id": None}
benchmark_id = None
try:
benchmark_id = SubmitBenchmark.run(
@@ -207,7 +207,7 @@ def register_benchmark(
skip_compatibility_tests=skip_compatibility_tests,
)
return_response["status"] = "success"
- return_response["benchmark_id"] = benchmark_id
+ return_response["entity_id"] = benchmark_id
notification_message = "Benchmark successfully registered!"
except Exception as exp:
return_response["status"] = "failed"
diff --git a/cli/medperf/web_ui/common.py b/cli/medperf/web_ui/common.py
index e6e982c8b..796dc70e2 100644
--- a/cli/medperf/web_ui/common.py
+++ b/cli/medperf/web_ui/common.py
@@ -27,36 +27,51 @@
)
from medperf.web_ui.schemas import WebUITask
-from medperf.web_ui.utils import generate_uuid
-
templates_folder_path = Path(resources.files("medperf.web_ui")) / "templates"
templates = Jinja2Templates(directory=templates_folder_path)
logger = logging.getLogger(__name__)
-ALLOWED_PATHS = ["/events", "/notifications", "/current_task", "/fetch-yaml"]
+ALLOWED_PATHS = [
+ "/events",
+ "/notifications",
+ "/api/running_tasks",
+ "/api/stop_task",
+ "/aggregators/run",
+ "/datasets/start_training",
+ "/settings/activate_profile",
+]
def initialize_state_task(request: Request, task_name: str) -> str:
form_data = dict(anyio.from_thread.run(lambda: request.form()))
- new_task_id = generate_uuid()
+ new_task_id = task_name
config.ui.start_task(new_task_id)
- request.app.state.task = WebUITask(
- id=new_task_id, name=task_name, running=True, formData=form_data
- )
+ task = WebUITask(id=new_task_id, name=task_name, running=True, formData=form_data)
+ request.app.state.active_tasks[task_name] = task
+ request.app.state.task = task
request.app.state.task_running = True
return new_task_id
-def reset_state_task(request: Request):
- current_task = request.app.state.task
+def reset_state_task(request: Request, task_id: str = None):
+ if task_id is None:
+ task_id = getattr(request.app.state.task, "id", None)
+ if task_id is None:
+ return
+ active_tasks = request.app.state.active_tasks
+ current_task = active_tasks.pop(task_id, None)
+ if current_task is None:
+ return
current_task.set_running(False)
+
+ if not active_tasks:
+ request.app.state.task_running = False
+
if len(request.app.state.old_tasks) == 10:
request.app.state.old_tasks.pop(0)
request.app.state.old_tasks.append(current_task)
- request.app.state.task = WebUITask()
- request.app.state.task_running = False
def custom_exception_handler(request: Request, exc: Exception):
diff --git a/cli/medperf/web_ui/containers/routes.py b/cli/medperf/web_ui/containers/routes.py
index 1ded38765..647ff1be6 100644
--- a/cli/medperf/web_ui/containers/routes.py
+++ b/cli/medperf/web_ui/containers/routes.py
@@ -113,9 +113,9 @@ def register_container(
decryption_file: str = Form(None),
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="container_registration")
+ initialize_state_task(request, task_name="register_container")
- return_response = {"status": "", "error": "", "container_id": None}
+ return_response = {"status": "", "error": "", "entity_id": None}
container_info = {
"name": name,
"additional_files_tarball_url": additional_file,
@@ -131,7 +131,7 @@ def register_container(
decryption_key=decryption_file,
)
return_response["status"] = "success"
- return_response["container_id"] = container_id
+ return_response["entity_id"] = container_id
notification_message = "Container successfully registered"
except Exception as exp:
return_response["status"] = "failed"
diff --git a/cli/medperf/web_ui/datasets/routes.py b/cli/medperf/web_ui/datasets/routes.py
index 17abfcb82..a7b9131b0 100644
--- a/cli/medperf/web_ui/datasets/routes.py
+++ b/cli/medperf/web_ui/datasets/routes.py
@@ -1,6 +1,7 @@
import os
import logging
-from typing import List
+import threading
+from typing import List, Optional
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi import Request, APIRouter, Depends, Form
@@ -24,6 +25,10 @@
from medperf.entities.benchmark import Benchmark
from medperf.entities.execution import Execution
from medperf.entities.model import Model
+from medperf.entities.training_exp import TrainingExp
+from medperf.commands.association.utils import get_user_associations
+from medperf.commands.dataset.associate_training import AssociateTrainingDataset
+from medperf.commands.dataset.train import TrainingExecution
from medperf.web_ui.common import (
templates,
check_user_ui,
@@ -31,6 +36,7 @@
initialize_state_task,
reset_state_task,
)
+from medperf.web_ui.utils import get_container_type
logger = logging.getLogger(__name__)
@@ -47,9 +53,7 @@ def datasets_ui(
my_user_id = get_medperf_user_data()["id"]
if mine_only:
filters["owner"] = my_user_id
- datasets = Dataset.all(
- filters=filters,
- )
+ datasets = Dataset.all(filters=filters)
datasets = sorted(datasets, key=lambda x: x.created_at, reverse=True)
# sort by (mine recent) (mine oldish), (other recent), (other oldish)
@@ -58,7 +62,11 @@ def datasets_ui(
datasets = mine_datasets + other_datasets
return templates.TemplateResponse(
"dataset/datasets.html",
- {"request": request, "datasets": datasets, "mine_only": mine_only},
+ {
+ "request": request,
+ "datasets": datasets,
+ "mine_only": mine_only,
+ },
)
@@ -68,115 +76,156 @@ def dataset_detail_ui( # noqa
dataset_id: int,
current_user: bool = Depends(check_user_ui),
):
+ user_obj = get_medperf_user_object()
+
dataset = Dataset.get(dataset_id)
dataset.read_report()
dataset.read_statistics()
prep_cube = Cube.get(cube_uid=dataset.data_preparation_mlcube)
- benchmark_assocs = Dataset.get_benchmarks_associations(dataset_uid=dataset_id)
- benchmark_associations = {}
- for assoc in benchmark_assocs:
- benchmark_associations[assoc["benchmark"]] = assoc
- # benchmark_associations = sort_associations_display(benchmark_associations)
-
- # Get all relevant benchmarks for making an association
- benchmarks = Benchmark.all()
- valid_benchmarks = {
- b.id: b
- for b in benchmarks
- if b.data_preparation_mlcube == dataset.data_preparation_mlcube
- }
- for benchmark in valid_benchmarks:
- ref_model_id = valid_benchmarks[benchmark].reference_model
- valid_benchmarks[benchmark].reference_model = Model.get(ref_model_id)
- dataset_is_operational = dataset.is_operational()
- dataset_is_prepared = dataset.is_ready() or dataset_is_operational
- approved_benchmarks = [
- i
- for i in benchmark_associations
- if benchmark_associations[i]["approval_status"] == "APPROVED"
- ]
- user_obj = get_medperf_user_object()
my_user_id = user_obj.id
is_owner = my_user_id == dataset.owner
- dataset_hash_mismatch = None
- if dataset_is_operational and is_owner:
- dataset_hash_mismatch = not dataset.check_hash()
-
- # Get all results
- results = []
- if benchmark_assocs:
- user_id = user_obj.id
- results = Execution.all(filters={"owner": user_id})
- results = filter_latest_executions(results)
-
- # Fetch models associated with each benchmark
- benchmark_models = {}
- for assoc in benchmark_assocs:
- if assoc["approval_status"] != "APPROVED":
- continue # if association is not approved we cannot list its models
- models_uids = Benchmark.get_models_uids(benchmark_uid=assoc["benchmark"])
- models = [Model.get(model_uid) for model_uid in models_uids]
- benchmark_models[assoc["benchmark"]] = models
- for model in models + [valid_benchmarks[assoc["benchmark"]].reference_model]:
- model._encrypted = model.is_encrypted()
- model._requires_cc = model.requires_cc()
- if model._encrypted:
- model.access_status = check_access_to_container(model.container.id)
- if model._requires_cc:
- if not dataset.is_cc_configured():
- reason = "Your dataset is not configured for CC yet"
- can_run = False
- elif not model.is_cc_configured():
- reason = "Wait for model owner to configure their CC settings"
- can_run = False
- elif not user_obj.is_cc_configured():
- reason = (
- "You haven't configured your workload run settings for CC yet"
- )
- can_run = False
- else:
- reason = ""
- can_run = True
- model.cc_run_status = {"can_run": can_run, "reason": reason}
- model.result = None
- for result in results:
- if (
- result.benchmark == assoc["benchmark"]
- and result.dataset == dataset_id
- and result.model == model.id
- ):
- model.result = result.todict()
- model.result["results_exist"] = (
- result.is_executed() or result.finalized
- )
- if model.result["results_exist"]:
- model.result["results"] = result.read_results()
-
+ dataset_is_operational = dataset.is_operational()
+ dataset_is_prepared = dataset.is_ready() or dataset_is_operational
report_exists = os.path.exists(dataset.report_path)
+ ui_mode = request.app.state.ui_mode
cc_config_defaults = dataset.get_cc_config()
cc_configured = dataset.is_cc_configured()
+ cc_initialized = dataset.is_cc_initialized()
+ cc_last_synced = dataset.get_last_synced()
+ context = {
+ "request": request,
+ "dataset": dataset,
+ "prep_cube": prep_cube,
+ "dataset_is_prepared": dataset_is_prepared,
+ "dataset_is_operational": dataset_is_operational,
+ "is_owner": is_owner,
+ "report_exists": report_exists,
+ "cc_config_defaults": cc_config_defaults,
+ "cc_configured": cc_configured,
+ "cc_initialized": cc_initialized,
+ "cc_last_synced": cc_last_synced,
+ }
- return templates.TemplateResponse(
- "dataset/dataset_detail.html",
- {
- "request": request,
- "dataset": dataset,
- "prep_cube": prep_cube,
- "dataset_is_prepared": dataset_is_prepared,
- "dataset_is_operational": dataset_is_operational,
- "dataset_hash_mismatch": dataset_hash_mismatch,
- "benchmark_associations": benchmark_associations, #
- "benchmarks": valid_benchmarks, # Benchmarks that can be associated
- "benchmark_models": benchmark_models, # Pass associated models without status
- "approved_benchmarks": approved_benchmarks,
- "is_owner": is_owner,
- "report_exists": report_exists,
- "cc_config_defaults": cc_config_defaults,
- "cc_configured": cc_configured,
- },
- )
+ if ui_mode == request.app.state.EVALUATION_MODE:
+ benchmark_assocs = Dataset.get_benchmarks_associations(dataset_uid=dataset_id)
+ benchmark_associations = {}
+ for assoc in benchmark_assocs:
+ benchmark_associations[assoc["benchmark"]] = assoc
+ # benchmark_associations = sort_associations_display(benchmark_associations)
+
+ # Get all relevant benchmarks for making an association
+ benchmarks = Benchmark.all()
+ valid_benchmarks = {
+ b.id: b
+ for b in benchmarks
+ if b.data_preparation_mlcube == dataset.data_preparation_mlcube
+ }
+ approved_benchmarks = [
+ i
+ for i in benchmark_associations
+ if benchmark_associations[i]["approval_status"] == "APPROVED"
+ ]
+ # Get all results
+ results = []
+ if benchmark_assocs:
+ user_id = user_obj.id
+ results = Execution.all(filters={"owner": user_id})
+ results = filter_latest_executions(results)
+
+ # Fetch models associated with each benchmark
+ benchmark_models = {}
+ for assoc in benchmark_assocs:
+ if assoc["approval_status"] != "APPROVED":
+                continue  # models of an unapproved association cannot be listed
+ models_uids = Benchmark.get_models_uids(benchmark_uid=assoc["benchmark"])
+ reference_model_id = valid_benchmarks[assoc["benchmark"]].reference_model
+ models_uids.insert(0, reference_model_id)
+ models = [Model.get(model_uid) for model_uid in models_uids]
+            # If any model requires CC, drop the reference model (inserted at index 0 above).
+ for model in models:
+ if model.requires_cc():
+ models.pop(0)
+ break
+ benchmark_models[assoc["benchmark"]] = models
+ for model in benchmark_models[assoc["benchmark"]]:
+ model._encrypted = model.is_encrypted()
+ model._requires_cc = model.requires_cc()
+ if model._encrypted:
+ model.access_status = check_access_to_container(model.container.id)
+ if model._requires_cc:
+ if not dataset.is_cc_initialized():
+ reason = "Your dataset is not configured for CC yet"
+ can_run = False
+ elif not model.is_cc_initialized():
+ reason = "Wait for model owner to configure their CC settings"
+ can_run = False
+ elif not user_obj.is_cc_initialized():
+ reason = "You haven't configured your workload run settings for CC yet"
+ can_run = False
+ else:
+ reason = ""
+ can_run = True
+ model.cc_run_status = {"can_run": can_run, "reason": reason}
+ model.result = None
+ for result in results:
+ if (
+ result.benchmark == assoc["benchmark"]
+ and result.dataset == dataset_id
+ and result.model == model.id
+ ):
+ model.result = result.todict()
+ model.result["results_exist"] = (
+ result.is_executed() or result.finalized
+ )
+ if model.result["results_exist"]:
+ model.result["results"] = result.read_results()
+
+ context.update(
+ {
+ "benchmark_associations": benchmark_associations,
+ "benchmarks": valid_benchmarks,
+ "benchmark_models": benchmark_models,
+ "approved_benchmarks": approved_benchmarks,
+ }
+ )
+
+ else:
+ training_associations = {}
+ available_training_experiments = []
+ try:
+ user_training_assocs = get_user_associations(
+ experiment_type="training_exp", component_type="dataset"
+ )
+ for a in user_training_assocs:
+ if a.get("dataset") == dataset_id:
+ training_associations[a["training_exp"]] = a
+ all_training = TrainingExp.all()
+ available_training_experiments = [
+ t
+ for t in all_training
+ if t.data_preparation_mlcube == dataset.data_preparation_mlcube
+ ]
+ except Exception as e:
+ logger.warning("Could not load training associations: %s", e)
+
+ experiments_by_id = {e.id: e for e in available_training_experiments}
+ for exp_id in training_associations:
+ if exp_id not in experiments_by_id:
+ try:
+ experiments_by_id[exp_id] = TrainingExp.get(exp_id)
+ except Exception:
+ pass
+ context.update(
+ {
+ "training_associations": training_associations,
+ "available_training_experiments": available_training_experiments,
+ "experiments_by_id": experiments_by_id,
+ }
+ )
+
+ return templates.TemplateResponse("dataset/dataset_detail.html", context)
@router.get("/register/ui", response_class=HTMLResponse)
@@ -184,19 +233,36 @@ def create_dataset_ui(
request: Request,
current_user: bool = Depends(check_user_ui),
):
- # Fetch the list of benchmarks to populate the benchmark dropdown
- benchmarks = Benchmark.all()
- # Render the dataset creation form with the list of benchmarks
- return templates.TemplateResponse(
- "dataset/register_dataset.html", {"request": request, "benchmarks": benchmarks}
- )
+ ui_mode = request.app.state.ui_mode
+ context = {"request": request}
+
+ if ui_mode == request.app.state.EVALUATION_MODE:
+ benchmarks = Benchmark.all()
+ context["benchmarks"] = benchmarks
+ else:
+ my_containers = Cube.all()
+ containers = []
+ for container in my_containers:
+ container_obj = {
+ "id": container.id,
+ "name": container.name,
+ "type": get_container_type(container),
+ }
+ containers.append(container_obj)
+ data_prep_containers = [
+ c for c in containers if c["type"] == "data-prep-container"
+ ]
+ context["data_prep_containers"] = data_prep_containers
+
+ return templates.TemplateResponse("dataset/register_dataset.html", context)
@router.post("/register/", response_class=JSONResponse)
def register_dataset(
request: Request,
submit_as_prepared: bool = Form(False),
- benchmark: int = Form(...),
+ benchmark: Optional[int] = Form(None),
+ prep_cube_uid: Optional[int] = Form(None),
name: str = Form(...),
description: str = Form(...),
location: str = Form(...),
@@ -204,13 +270,13 @@ def register_dataset(
labels_path: str = Form(...),
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="dataset_registration")
- return_response = {"status": "", "dataset_id": None, "error": ""}
+ initialize_state_task(request, task_name="register_dataset")
+ return_response = {"status": "", "entity_id": None, "error": ""}
dataset_id = None
try:
dataset_id = DataCreation.run(
benchmark_uid=benchmark,
- prep_cube_uid=None,
+ prep_cube_uid=prep_cube_uid,
data_path=data_path,
labels_path=labels_path,
metadata_path=None,
@@ -221,7 +287,7 @@ def register_dataset(
submit_as_prepared=bool(submit_as_prepared),
)
return_response["status"] = "success"
- return_response["dataset_id"] = dataset_id
+ return_response["entity_id"] = dataset_id
notification_message = "Dataset successfully registered"
except Exception as exp:
return_response["status"] = "failed"
@@ -249,7 +315,7 @@ def prepare(
dataset_id: int = Form(...),
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="dataset_preparation")
+ initialize_state_task(request, task_name="prepare")
return_response = {"status": "", "dataset_id": None, "error": ""}
try:
@@ -333,6 +399,84 @@ def associate(
return return_response
+@router.post("/associate_training", response_class=JSONResponse)
+def associate_training(
+ request: Request,
+ dataset_id: int = Form(...),
+ training_exp_id: int = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ initialize_state_task(request, task_name="dataset_training_association")
+ return_response = {"status": "", "error": ""}
+ try:
+ AssociateTrainingDataset.run(
+ data_uid=dataset_id,
+ training_exp_uid=training_exp_id,
+ approved=True,
+ )
+ return_response["status"] = "success"
+ notification_message = (
+ "Successfully requested dataset association with training experiment"
+ )
+ except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "Failed to request association with training experiment"
+ logger.exception(exp)
+
+ config.ui.end_task(return_response)
+ reset_state_task(request)
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=f"/datasets/ui/display/{dataset_id}",
+ )
+ return return_response
+
+
+def _run_training_worker(
+ request: Request, training_exp_id: int, dataset_id: int, task_id: str
+):
+ redirect_url = f"/datasets/ui/display/{dataset_id}"
+ return_response = {"status": "", "error": ""}
+ notification_message = "Training successfully finished"
+ config.ui.set_task_id(task_id)
+ try:
+ TrainingExecution.run(training_exp_id=training_exp_id, data_uid=dataset_id)
+ return_response["status"] = "success"
+ except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "An error occurred during training execution"
+ logger.exception(exp)
+
+ config.ui.end_task(return_response)
+ reset_state_task(request, task_id)
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=redirect_url,
+ )
+
+
+@router.post("/start_training", response_class=JSONResponse)
+def start_training(
+ request: Request,
+ dataset_id: int = Form(...),
+ training_exp_id: int = Form(...),
+ current_user: bool = Depends(check_user_api),
+):
+ task_id = initialize_state_task(request, task_name="start_training")
+
+ threading.Thread(
+ target=_run_training_worker,
+ args=(request, training_exp_id, dataset_id, task_id),
+ daemon=True,
+ ).start()
+
+ return {"status": "started", "error": ""}
+
+
@router.post("/run", response_class=JSONResponse)
def run(
request: Request,
@@ -342,7 +486,7 @@ def run(
run_all: bool = Form(...),
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="benchmark_run")
+ initialize_state_task(request, task_name="run_benchmark")
return_response = {"status": "", "error": ""}
try:
@@ -377,7 +521,7 @@ def submit_result(
result_id: str = Form(...),
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="result_submit")
+ initialize_state_task(request, task_name="submit_result")
return_response = {"status": "", "error": ""}
try:
@@ -411,9 +555,7 @@ def export_dataset_ui(
dataset.read_statistics()
prep_cube = Cube.get(cube_uid=dataset.data_preparation_mlcube)
dataset_is_operational = dataset.state == "OPERATION"
- dataset_is_prepared = ( # TODO: should we use submitted_as_prepared here?
- dataset.submitted_as_prepared or dataset.is_ready() or dataset_is_operational
- )
+ dataset_is_prepared = dataset.is_ready() or dataset_is_operational
report_exists = os.path.exists(dataset.report_path)
return templates.TemplateResponse(
@@ -437,7 +579,7 @@ def export_dataset(
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="dataset_export")
+ initialize_state_task(request, task_name="export_dataset")
return_response = {"status": "", "error": "", "dataset_id": dataset_id}
try:
@@ -481,7 +623,7 @@ def import_dataset(
current_user: bool = Depends(check_user_api),
):
- initialize_state_task(request, task_name="dataset_import")
+ initialize_state_task(request, task_name="import_dataset")
return_response = {"status": "", "error": "", "dataset_id": dataset_id}
try:
@@ -510,7 +652,7 @@ def import_dataset(
def edit_cc_config(
request: Request,
entity_id: int = Form(...),
- require_cc: bool = Form(False),
+ configure_cc: bool = Form(False),
project_id: str = Form(""),
project_number: str = Form(""),
bucket: str = Form(""),
@@ -531,7 +673,7 @@ def edit_cc_config(
"wip": wip,
"wip_provider": wip_provider,
}
- if not require_cc:
+ if not configure_cc:
args = {}
initialize_state_task(request, task_name="data_update_cc_config")
return_response = {"status": "", "error": ""}
@@ -557,12 +699,27 @@ def edit_cc_config(
@router.post("/sync_cc_policy", response_class=JSONResponse)
def sync_cc_policy(
+ request: Request,
entity_id: int = Form(...),
current_user: bool = Depends(check_user_api),
):
+ initialize_state_task(request, task_name="data_update_cc_policy")
+ return_response = {"status": "", "error": ""}
try:
DatasetUpdateCCPolicy.run(entity_id)
- return {"status": "success", "error": ""}
+ return_response["status"] = "success"
+ notification_message = "Successfully updated dataset CC policy!"
except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "Failed to update dataset CC policy"
logger.exception(exp)
- return {"status": "failed", "error": str(exp)}
+
+ config.ui.end_task(return_response)
+ reset_state_task(request)
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=f"/datasets/ui/display/{entity_id}",
+ )
+ return return_response
diff --git a/cli/medperf/web_ui/events.py b/cli/medperf/web_ui/events.py
index 58a1d987c..7e52bbc12 100644
--- a/cli/medperf/web_ui/events.py
+++ b/cli/medperf/web_ui/events.py
@@ -41,8 +41,11 @@ def get_task_id(request: Request, current_user: bool = Depends(check_user_api)):
def process_event(request: Request, event: EventBase):
- if request.app.state.task.running and event.task_id == request.app.state.task.id:
- request.app.state.task.add_log(event)
+ if not event.task_id:
+ return
+ active = request.app.state.active_tasks
+ if event.task_id in active:
+ active[event.task_id].add_log(event)
return
for task in request.app.state.old_tasks:
if task.id == event.task_id:
@@ -54,15 +57,20 @@ def sse_frame_event(event: EventBase):
return f"id: {event.id}\ndata: {event.json()}\n\n"
-def should_process_old(request: Request, stream_old: bool):
+def should_process_old(request: Request, task_name: str, stream_old: bool):
if not stream_old:
return
- for old_event in request.app.state.task.logs.copy():
+
+ active = request.app.state.active_tasks.get(task_name)
+ if not active:
+ return
+
+ for old_event in active.logs.copy():
yield sse_frame_event(old_event)
-def event_generator(request: Request, stream_old: bool):
- yield from should_process_old(request, stream_old)
+def event_generator(request: Request, task_name: str, stream_old: bool):
+ yield from should_process_old(request, task_name, stream_old)
while True:
event_processed = False
@@ -70,10 +78,7 @@ def event_generator(request: Request, stream_old: bool):
if anyio.from_thread.run(request.is_disconnected):
break
try:
- event = config.ui.get_event(timeout=1.0)
- if not event.task_id:
- continue
-
+ event = config.ui.get_event(task_id=task_name, timeout=1.0)
process_event(request, event)
event_processed = True
yield sse_frame_event(event)
@@ -90,11 +95,12 @@ def event_generator(request: Request, stream_old: bool):
@router.get("/events", response_class=StreamingResponse)
def stream_events(
request: Request,
+ task_name: str,
stream_old: bool = False,
current_user: bool = Depends(check_user_api),
):
return StreamingResponse(
- event_generator(request, stream_old),
+ event_generator(request, task_name, stream_old),
media_type="text/event-stream; charset=utf-8",
)
@@ -142,12 +148,16 @@ def acknowledge_event(
@router.post("/events")
def respond(
request: Request,
+ task_name: str = Form(...),
is_approved: bool = Form(...),
current_user: bool = Depends(check_user_api),
):
config.ui.set_response({"value": is_approved})
# Remove the prompt event after responding to the prompt
- for event in request.app.state.task.logs:
+ active = request.app.state.active_tasks.get(task_name)
+ if not active:
+ return
+ for event in active.logs.copy():
if event.kind == "event" and event.type == "prompt":
event.type = "prompt_done"
event.approved = is_approved
diff --git a/cli/medperf/web_ui/medperf_login.py b/cli/medperf/web_ui/medperf_login.py
index ec974ec3e..ba448a550 100644
--- a/cli/medperf/web_ui/medperf_login.py
+++ b/cli/medperf/web_ui/medperf_login.py
@@ -26,9 +26,21 @@ def login_form(
redirected: str = "false",
current_user: bool = Depends(check_user_ui),
):
+ account_info = read_user_account()
+ msg = ""
+ if account_info is not None:
+ msg = (
+ f"You are already logged in as {account_info['email']}."
+ " Logout before logging in again"
+ )
redirected = redirected.lower() == "true"
return templates.TemplateResponse(
- "medperf_login.html", {"request": request, "redirected": redirected}
+ "medperf_login.html",
+ {
+ "request": request,
+ "redirected": redirected,
+ "already_logged_in_msg": msg if account_info else None,
+ },
)
diff --git a/cli/medperf/web_ui/models/routes.py b/cli/medperf/web_ui/models/routes.py
index 65f827773..81a9cc39e 100644
--- a/cli/medperf/web_ui/models/routes.py
+++ b/cli/medperf/web_ui/models/routes.py
@@ -78,6 +78,8 @@ def model_detail_ui(
cc_config_defaults = model.get_cc_config()
cc_configured = model.is_cc_configured()
+ cc_initialized = model.is_cc_initialized()
+ cc_last_synced = model.get_last_synced()
return templates.TemplateResponse(
"model/model_detail.html",
{
@@ -92,6 +94,8 @@ def model_detail_ui(
"benchmarks": benchmarks,
"cc_config_defaults": cc_config_defaults,
"cc_configured": cc_configured,
+ "cc_initialized": cc_initialized,
+ "cc_last_synced": cc_last_synced,
},
)
@@ -104,7 +108,7 @@ def associate(
current_user: bool = Depends(check_user_api),
):
initialize_state_task(request, task_name="model_association")
- return_response = {"status": "", "error": ""}
+ return_response = {"status": "", "error": "", "entity_id": model_id}
try:
AssociateModel.run(model_uid=model_id, benchmark_uid=benchmark_id)
return_response["status"] = "success"
@@ -129,7 +133,7 @@ def associate(
def edit_cc_config(
request: Request,
entity_id: int = Form(...),
- require_cc: bool = Form(False),
+ configure_cc: bool = Form(False),
project_id: str = Form(""),
project_number: str = Form(""),
bucket: str = Form(""),
@@ -150,7 +154,7 @@ def edit_cc_config(
"wip": wip,
"wip_provider": wip_provider,
}
- if not require_cc:
+ if not configure_cc:
args = {}
initialize_state_task(request, task_name="model_update_cc_config")
@@ -177,12 +181,27 @@ def edit_cc_config(
@router.post("/sync_cc_policy", response_class=JSONResponse)
def sync_cc_policy(
+ request: Request,
entity_id: int = Form(...),
current_user: bool = Depends(check_user_api),
):
+ initialize_state_task(request, task_name="model_update_cc_policy")
+ return_response = {"status": "", "error": ""}
try:
ModelUpdateCCPolicy.run(entity_id)
- return {"status": "success", "error": ""}
+ return_response["status"] = "success"
+ notification_message = "Successfully updated model CC policy!"
except Exception as exp:
+ return_response["status"] = "failed"
+ return_response["error"] = str(exp)
+ notification_message = "Failed to update model CC policy"
logger.exception(exp)
- return {"status": "failed", "error": str(exp)}
+
+ config.ui.end_task(return_response)
+ reset_state_task(request)
+ config.ui.add_notification(
+ message=notification_message,
+ return_response=return_response,
+ url=f"/models/ui/display/{entity_id}",
+ )
+ return return_response
diff --git a/cli/medperf/web_ui/schemas.py b/cli/medperf/web_ui/schemas.py
index 92685c47d..d8e7172c1 100644
--- a/cli/medperf/web_ui/schemas.py
+++ b/cli/medperf/web_ui/schemas.py
@@ -245,35 +245,57 @@ def mark_notification_as_read(self, notification_id) -> None:
return
+def _buffer_key(task_id: Optional[str]) -> str:
+ """Key for per-task buffer; chunks must not mix events from different tasks."""
+ return task_id if task_id else "_"
+
+
class EventsManager:
+ """
+ Buffers chunkable events (interactive prints) per task_id, flushes by size/age or
+ when a non-chunkable event arrives. Events are enqueued to a per-task queue so
+ each SSE stream (one per task) only receives that task's events.
+
+ flush_all_buffers: before enqueueing a non-chunkable event we flush every task's
+ buffer so chunkable events are sent first. Also used when a task ends.
+ """
+
def __init__(self):
- self.events: Queue[EventBase] = Queue()
- self.buffer: List[Event] = []
- self.size = 0 # For bytes check
- self.created_at = 0 # For age check - will be time.monotonic()
+ self._event_queues: Dict[str, Queue] = {} # _buffer_key(task_id) -> Queue[EventBase]
+ self._queues_lock = threading.Lock()
+ self._buffers: Dict[str, Dict] = {} # _buffer_key(task_id) -> {"events", "size", "created_at"}
self.lock = threading.Lock()
- self.stop_event = threading.Event()
+ self._stop_event = threading.Event()
+ self._age_worker_started = False
self.max_chunk_length = config.webui_max_chunk_length
self.max_chunk_age = config.webui_max_chunk_age
self.max_chunk_size = config.webui_max_chunk_size
- def add_event(self, event: Event):
- """Append an event to the chunk buffer and update its size.
+ def _queue_for(self, task_id: Optional[str]) -> Queue:
+ key = _buffer_key(task_id)
+ with self._queues_lock:
+ if key not in self._event_queues:
+ self._event_queues[key] = Queue()
+ return self._event_queues[key]
- If the buffer is empty, set its created_at timestamp (monotonic) to now.
+ def _get_or_create_buffer(self, key: str) -> Dict:
+ if key not in self._buffers:
+ self._buffers[key] = {"events": [], "size": 0, "created_at": 0.0}
+ return self._buffers[key]
- Args:
- event (Event): The event to buffer (typically an interactive print line).
- """
+ def add_event(self, event: Event):
+ """Append an event to the chunk buffer for its task_id."""
- if not self.buffer:
- self.created_at = time.monotonic()
- self.buffer.append(event)
- self.size += event.get_size_bytes()
+ key = _buffer_key(event.task_id)
+ buf = self._get_or_create_buffer(key)
+ if not buf["events"]:
+ buf["created_at"] = time.monotonic()
+ buf["events"].append(event)
+ buf["size"] += event.get_size_bytes()
def process_event(self, event: Event):
"""
- Process a single event: buffer chunkable events, flush if needed,
+ Process a single event: buffer chunkable events per task_id, flush if needed,
or immediately enqueue non-chunkable events.
"""
@@ -284,30 +306,20 @@ def process_event(self, event: Event):
return
with self.lock:
- self.flush_buffer()
+ self.flush_all_buffers()
self.enqueue_event(event)
def enqueue_event(self, event: EventBase):
- """Enqueue an event (chunked or single) into the events queue.
-
- Args:
- event (EventBase): The event or chunk to enqueue.
- """
- self.events.put_nowait(event)
+ """Enqueue an event into the queue for its task_id so only that task's SSE receives it."""
+ self._queue_for(event.task_id).put_nowait(event)
- def dequeue_event(self, timeout: Optional[float]) -> Optional[EventBase]:
- """
- Return the next event or chunk from the queue.
-
- Args:
- timeout (float | None): Seconds to wait for an event.
- """
-
- return self.events.get(block=True, timeout=timeout)
+ def dequeue_event(self, task_id: Optional[str], timeout: Optional[float]) -> Optional[EventBase]:
+ """Return the next event or chunk for the given task_id (used by SSE stream)."""
+ return self._queue_for(task_id).get(block=True, timeout=timeout)
def _build_chunk(self, events: List[Event], size: int) -> EventChunk:
- """Build an EventChunk object from a list of events."""
+ """Build an EventChunk from events (all same task_id)."""
for i, ev in enumerate(events, 1):
ev.id = i
@@ -319,66 +331,68 @@ def _build_chunk(self, events: List[Event], size: int) -> EventChunk:
size_bytes=size,
)
- def flush_buffer(self):
- """
- Flush the event buffer: if it contains multiple events, emit them as
- an EventChunk; otherwise enqueue the single event. The buffer is then reset.
- """
- if not self.buffer:
+ def _flush_one_buffer(self, key: str) -> None:
+ """Flush a single task's buffer: emit chunk or single event, then clear it."""
+ buf = self._buffers.get(key)
+ if not buf or not buf["events"]:
return
- buffer = list(self.buffer)
- size = self.size
- self.buffer.clear()
- self.size = 0
- self.created_at = 0
-
- if len(buffer) != 1:
- chunk = self._build_chunk(buffer, size)
- self.events.put_nowait(chunk)
- return
-
- self.events.put_nowait(buffer[0])
-
- def flush_by_size(self):
- """Flush the buffer if it exceeds the max event count or max byte size."""
-
- length_exceeded = len(self.buffer) >= self.max_chunk_length
- size_bytes_exceeded = self.size >= self.max_chunk_size
-
- if self.buffer and (length_exceeded or size_bytes_exceeded):
- self.flush_buffer()
-
- def flush_by_age(self):
- """Flush the buffer if the oldest event exceeds the max age."""
-
- time_exceeded = (time.monotonic() - self.created_at) >= self.max_chunk_age
- if self.buffer and time_exceeded:
- self.flush_buffer()
-
- def flush_by_age_worker(self):
- """
- Background loop that periodically flushes the buffer if it exceeds max age.
-
- Runs until 'stop_event' is set.
- """
-
- while not self.stop_event.is_set():
+ events = list(buf["events"])
+ size = buf["size"]
+ buf["events"].clear()
+ buf["size"] = 0
+ buf["created_at"] = 0.0
+ self._buffers.pop(key, None) # remove empty buffer
+
+ task_id = events[0].task_id
+ if len(events) != 1:
+ self._queue_for(task_id).put_nowait(self._build_chunk(events, size))
+ else:
+ self._queue_for(task_id).put_nowait(events[0])
+
+ def flush_all_buffers(self) -> None:
+ """Flush every task's buffer (e.g. before a non-chunkable event)."""
+ for key in list(self._buffers.keys()):
+ self._flush_one_buffer(key)
+
+ def flush_by_size(self) -> None:
+ """Flush any task's buffer that exceeds max length or byte size."""
+ for key in list(self._buffers.keys()):
+ buf = self._buffers.get(key)
+ if not buf or not buf["events"]:
+ continue
+ length_ok = len(buf["events"]) >= self.max_chunk_length
+ size_ok = buf["size"] >= self.max_chunk_size
+ if length_ok or size_ok:
+ self._flush_one_buffer(key)
+
+ def flush_by_age(self) -> None:
+ """Flush any task's buffer whose oldest event exceeds max age."""
+ now = time.monotonic()
+ for key in list(self._buffers.keys()):
+ buf = self._buffers.get(key)
+ if not buf or not buf["events"]:
+ continue
+ if (now - buf["created_at"]) >= self.max_chunk_age:
+ self._flush_one_buffer(key)
+
+ def _flush_by_age_worker(self) -> None:
+ """Background loop: periodically flush buffers that exceed max age."""
+ while not self._stop_event.is_set():
with self.lock:
self.flush_by_age()
time.sleep(0.2)
- def start_buffering(self):
- """Start the age-based flushing worker thread."""
-
- self.stop_event.clear()
-
- # worker for age-flushing
- age_worker = threading.Thread(target=self.flush_by_age_worker, daemon=True)
- age_worker.start()
-
- def stop_buffering(self):
- """Stop the age-flushing worker and flush any remaining buffered events."""
+ def start_buffering(self) -> None:
+ """Start the age-based flushing worker (idempotent; one worker for all tasks)."""
+ with self.lock:
+ if self._age_worker_started:
+ return
+ self._age_worker_started = True
+ self._stop_event.clear()
+ t = threading.Thread(target=self._flush_by_age_worker, daemon=True)
+ t.start()
- self.stop_event.set()
+ def stop_buffering(self) -> None:
+ """Flush all task buffers (e.g. when a task ends). Does not stop the age worker."""
with self.lock:
- self.flush_buffer()
+ self.flush_all_buffers()
diff --git a/cli/medperf/web_ui/settings.py b/cli/medperf/web_ui/settings.py
index 4717fd4a6..dfb75a942 100644
--- a/cli/medperf/web_ui/settings.py
+++ b/cli/medperf/web_ui/settings.py
@@ -37,6 +37,7 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)):
cc_config_defaults = {}
cc_configured = False
+ cc_initialized = False
if is_logged_in():
cas = CA.all()
@@ -46,6 +47,7 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)):
user = get_medperf_user_object()
cc_config_defaults = user.get_cc_config()
cc_configured = user.is_cc_configured()
+ cc_initialized = user.is_cc_initialized()
return templates.TemplateResponse(
"settings.html",
@@ -60,6 +62,7 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)):
"certificate_status": certificate_status,
"cc_config_defaults": cc_config_defaults,
"cc_configured": cc_configured,
+ "cc_initialized": cc_initialized,
},
)
@@ -221,7 +224,7 @@ def submit_certificate(
@router.post("/edit_cc_operator", response_class=JSONResponse)
def edit_cc_operator(
- require_cc: bool = Form(...),
+ configure_cc: bool = Form(False),
project_id: str = Form(""),
service_account_name: str = Form(""),
bucket: str = Form(""),
@@ -236,7 +239,7 @@ def edit_cc_operator(
"vm_zone": vm_zone,
"vm_name": vm_name,
}
- if not require_cc:
+ if not configure_cc:
args = {}
try:
diff --git a/cli/medperf/web_ui/static/css/common.css b/cli/medperf/web_ui/static/css/common.css
index ce5912706..26286e270 100644
--- a/cli/medperf/web_ui/static/css/common.css
+++ b/cli/medperf/web_ui/static/css/common.css
@@ -1,421 +1,517 @@
-:root {
- --footer-h: 160px;
-}
-html, body {
- height: 100%;
-}
-
-.page-wrap {
- min-height: calc(100vh - var(--footer-h));
- box-sizing: border-box;
-}
-
-footer.site-footer {
- height: var(--footer-h);
-}
-
-.badge-state-operational {
- background-color: lightgreen;
- color: black;
-}
-
-.badge-state-development {
- background-color: lightcoral;
- color: black;
-}
-
-.badge-approval-pending {
- background-color: coral;
- color: black;
-}
-
-.badge-approval-approved {
- background-color: lightgreen;
- color: black;
-}
-
-.badge-approval-rejected {
- background-color: red;
- color: black;
-}
-
-.badge-valid {
- background-color: lightgreen;
- color: black;
-}
-
-.badge-invalid {
- background-color: red;
- color: black;
-}
-
-.invalid-card {
- background-color: #ffe6e6;
-}
-
-.unfinalized-result-card {
- background-color: #ffe6e6;
-}
-
-.page-container {
- display: flex;
- flex-direction: column;
- min-height: 100vh;
-}
-
-.main-content {
- flex: 1;
- display: flex;
- flex-direction: column;
- align-items: center;
- padding: 20px;
-}
-
-.detail-container {
- display: flex;
- justify-content: center; /* Center the detail panel initially */
- width: 100%;
- align-items: center;
-}
-
-.detail-panel {
- flex: 1;
- max-width: 60%;
- padding-right: 20px;
- box-sizing: border-box;
- transition: max-width 0.3s, margin 0.3s; /* Smooth transition */
-}
-
-.yaml-panel {
- flex: 0 0 40%;
- max-width: 40%;
- min-width: 300px;
- height: 100vh;
- overflow: auto;
- display: none; /* Hidden initially */
-}
-
-.yaml-panel-visible .yaml-panel {
- justify-content: flex-start; /* Align detail panel to the left when YAML panel is visible */
-}
-
-.yaml-panel-visible .detail-panel {
- max-width: 60%; /* Adjust to take 60% width when YAML panel is visible */
- margin-right: 20px;
-}
-
-pre {
- white-space: pre-wrap;
-}
-
-/* Floating alert container at the bottom-right corner */
-.floating-alert {
- position: fixed;
- bottom: 20px;
- right: 20px;
- z-index: 1050; /* Ensure it appears above other content */
- width: 300px;
- max-width: 100%;
-}
-
-/*----from detail_base.html*/
-.card-body.d-flex {
- flex-wrap: wrap;
-}
-
-.card-text {
- word-wrap: break-word;
- white-space: normal;
-}
-
-.association-card {
- margin-bottom: 20px;
-}
-
-.benchmark-result-card {
- margin-bottom: 20px;
-}
-
-.associations-panel {
- display: flex;
- flex-wrap: wrap;
- justify-content: space-between;
- width: 100%;
- margin-top: 20px;
-}
-
-.associations-column {
- flex: 1;
- min-width: 150px;
- max-width: 48%;
- box-sizing: border-box;
-}
-
-#datasets-associations{
- display: flex;
- flex-wrap: wrap;
- justify-content: space-evenly;
-}
-
-#models-associations{
- display: flex;
- flex-wrap: wrap;
- justify-content: space-evenly;
-}
-
-#benchmark-results{
- display: flex;
- flex-wrap: wrap;
- justify-content: space-evenly;
-}
-
-/* Step container styling */
-.step-container {
- border-radius: 8px;
- background-color: #f9f9f9;
- padding: 10px;
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
- display: flex;
- justify-content: space-between;
- align-items: center;
-}
-
-/* Step buttons */
-.step-btn {
- padding: 5px 15px; /* Reduced padding */
- font-size: 0.9rem; /* Slightly smaller font for a more compact design */
- border-radius: 20px;
- transition: background-color 0.3s ease;
-}
-
-.step-btn i {
- font-size: 1rem;
-}
-
-/* Completed steps */
-.step-complete {
- display: flex;
- align-items: center;
- background-color: #e8f5e9;
- padding: 5px 15px; /* Reduced padding */
- border-radius: 20px;
-}
-
-.step-complete i {
- font-size: 1.2rem;
-}
-
-/* Step labels */
-.step-label {
- margin-top: 2px;
- font-size: 0.85rem; /* Smaller step labels */
- color: #6c757d;
-}
-
-/* Step dividers */
-.step-divider i {
- font-size: 1.2rem;
- margin-left: 5px;
- margin-right: 5px;
-}
-
-/* Consistent button styles */
-.step-btn, .step-complete {
- min-width: 120px; /* Adjusted width */
- text-align: center;
-}
-
-.step-btn:disabled {
- background-color: #e0e0e0;
- color: #6c757d;
-}
-
-/* Dropdown button adjustments */
-.dropdown-toggle {
- padding: 5px 15px;
- border-radius: 20px;
-}
-
-.hidden-element{
- display: none;
-}
-
-.bottom-buttons-panel{
- border:1px solid rgba(0, 0, 0, .125);
- border-radius:.25rem;
-}
-
-.benchmarks-dropdown li:active{
- background-color: unset;
- color: unset;
-}
-
-.tooltip-info:hover{
- cursor: help;
-}
-
-.dropdown-item:active{
- background-color: initial;
- color: initial;
-}
-
-.floating-alert {
- position: fixed;
- top: 20px;
- left: 50%;
- transform: translateX(-50%);
- z-index: 1055;
- width: fit-content;
- max-width: 90%;
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
- height: fit-content;
-}
-
-.language-yaml{
- background-color: #f8f9fa;
- padding: 15px;
- border-radius: 4px;
-}
-
-.modal-body {
- word-wrap: break-word;
- overflow-wrap: break-word;
- white-space: normal;
-}
-
-#notifications-list {
- max-height: 360px;
- overflow-y:auto;
-}
-
-#notifications-list > li{
- width: 340px!important;
-}
-
-.notification-text {
- word-wrap: break-word;
- overflow-wrap: break-word;
- max-width: 100%;
- white-space: normal;
-}
-
-#log-panel{
- height: 300px;
- overflow-y: scroll;
- background-color: #f8f9fa;
- padding: 15px;
-}
-
-#yaml-content{
- max-height: 300px;
- overflow-y: scroll;
-}
-
-#result-content{
- max-height: 300px;
- overflow-y: scroll;
-}
-
-label[for="switch"] {
- margin-left: -40px;
- cursor: pointer;
-}
-
-#switch {
- border: 1px solid gray;
- cursor: pointer;
-}
-
-#toast-container {
- z-index: 11;
-}
-
-#folder-list{
- max-height: 600px;
- overflow-y: auto;
-}
-
-.folder-item{
- cursor: pointer;
-}
-
-.file-item{
- color: grey;
-}
-
-.parent-disabled{
- cursor: not-allowed;
-}
-
-#folder-picker-modal-title{
- word-break: break-all;
-}
-
-.email-chip {
- display: inline-block;
- background-color: #e0e0e0;
- border-radius: 20px;
- padding: 5px 10px;
- margin: 2px;
- font-size: 14px;
- position: relative;
-}
-
-.email-chip .remove-btn {
- margin-left: 8px;
- cursor: pointer;
- font-weight: bold;
-}
-
-.email-container {
- display: flex;
- flex-wrap: wrap;
- padding: 5px;
- border: 1px solid #ccc;
- border-radius: 5px;
-}
-
-#datasets-associations-title, #models-associations-title, #benchmark-results-title{
- cursor: pointer;
-}
-
-.bmk-detail-toggle {
- transition: transform 0.3s ease;
-}
-
-.bmk-detail-toggle.rotated {
- transform: rotate(180deg);
-}
-
-#modal-yaml-content{
- max-height: 450px;
-}
-
-@keyframes spin {
- to {
- transform: rotate(360deg);
- }
-}
-
-@media (max-width: 992px) {
- .detail-panel {
- max-width: 100%;
- padding-right: 0;
- }
-
- .yaml-panel {
- display: none;
- }
-
- .detail-panel, .associations-panel {
- max-width: 100%;
- padding-right: 0;
- padding-left: 0;
- }
-}
-
-@media (max-width: 768px) {
- .detail-container {
- flex-direction: column;
- }
+:root {
+ --footer-h: 160px;
+}
+html, body {
+ height: 100%;
+}
+
+.page-wrap {
+ min-height: calc(100vh - var(--footer-h));
+ box-sizing: border-box;
+}
+
+footer.site-footer {
+ height: var(--footer-h);
+}
+
+.badge-state-operational {
+ background-color: lightgreen;
+ color: black;
+}
+
+.badge-state-development {
+ background-color: lightcoral;
+ color: black;
+}
+
+.badge-approval-pending {
+ background-color: coral;
+ color: black;
+}
+
+.badge-approval-approved {
+ background-color: lightgreen;
+ color: black;
+}
+
+.badge-approval-rejected {
+ background-color: red;
+ color: black;
+}
+
+.badge-valid {
+ background-color: lightgreen;
+ color: black;
+}
+
+.badge-invalid {
+ background-color: red;
+ color: black;
+}
+
+.invalid-card {
+ background-color: #fef2f2;
+}
+.dark .invalid-card {
+ background-color: rgba(127, 29, 29, 0.2);
+}
+
+.unfinalized-result-card {
+ background-color: #ffe6e6;
+}
+.dark .unfinalized-result-card {
+ background-color: rgba(127, 29, 29, 0.3);
+}
+
+.page-container {
+ display: flex;
+ flex-direction: column;
+ min-height: 100vh;
+}
+
+.main-content {
+ flex: 1;
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ padding: 20px;
+}
+
+.detail-container {
+ display: flex;
+ justify-content: center;
+ width: 100%;
+ align-items: center;
+}
+
+.detail-panel {
+ flex: 1;
+ width: 100%;
+ max-width: 100%;
+ box-sizing: border-box;
+}
+
+pre {
+ white-space: pre-wrap;
+}
+
+.floating-alert {
+ position: fixed;
+ bottom: 20px;
+ right: 20px;
+ z-index: 1060;
+ width: 300px;
+ max-width: 100%;
+}
+
+.card-body.d-flex {
+ flex-wrap: wrap;
+}
+
+.card-text {
+ word-wrap: break-word;
+ white-space: normal;
+}
+
+.association-card {
+ margin-bottom: 20px;
+}
+
+.benchmark-result-card {
+ margin-bottom: 20px;
+}
+
+.associations-panel {
+ display: flex;
+ flex-wrap: wrap;
+ justify-content: space-between;
+ width: 100%;
+ margin-top: 20px;
+}
+
+.associations-column {
+ flex: 1;
+ min-width: 150px;
+ max-width: 48%;
+ box-sizing: border-box;
+}
+
+#datasets-associations,
+#models-associations,
+#benchmark-results {
+ display: flex;
+ flex-wrap: wrap;
+ justify-content: space-evenly;
+}
+
+.step-container {
+ border-radius: 8px;
+ background-color: #f9f9f9;
+ padding: 10px;
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+.dark .step-container {
+ background-color: #374151;
+}
+
+.step-btn {
+ transition: opacity 0.2s ease, background-color 0.2s ease;
+}
+
+.step-complete {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ background-color: #dcfce7;
+ color: #166534;
+ border: 2px solid #86efac;
+}
+.dark .step-complete {
+ background-color: rgba(34, 197, 94, 0.2);
+ color: #86efac;
+ border-color: rgba(34, 197, 94, 0.4);
+}
+
+.step-label {
+ color: #6b7280;
+ line-height: 1.4;
+}
+.dark .step-label {
+ color: #9ca3af;
+}
+
+.step-divider i {
+ font-size: 1rem;
+ opacity: 0.7;
+}
+
+button:disabled,
+.step-btn:disabled {
+ opacity: 0.65;
+ cursor: not-allowed;
+ filter: saturate(0.6);
+}
+.dark button:disabled,
+.dark .step-btn:disabled {
+ opacity: 0.55;
+ filter: saturate(0.5) brightness(0.85);
+}
+.step-btn:disabled {
+ background-color: #b0b0b0 !important;
+ color: #6c757d !important;
+}
+.dark .step-btn:disabled {
+ background-color: #4b5563 !important;
+ color: #9ca3af !important;
+}
+
+.dropdown-toggle {
+ padding: 5px 15px;
+ border-radius: 20px;
+}
+
+.hidden-element{
+ display: none;
+}
+
+.bottom-buttons-panel{
+ border:1px solid rgba(0, 0, 0, .125);
+ border-radius:.25rem;
+}
+
+.benchmarks-dropdown li:active{
+ background-color: unset;
+ color: unset;
+}
+
+.tooltip-info:hover,
+.tooltip-icon {
+ cursor: help;
+}
+.tooltip-icon i {
+ font-size: 1.125rem;
+}
+.step-btn,
+.view-result-btn,
+.run-all-btn,
+.yaml-link {
+ cursor: pointer;
+}
+.step-btn:disabled,
+button:disabled {
+ cursor: not-allowed;
+}
+#page-modal-footer button,
+#page-modal-close-btn,
+.close-modal-btn,
+.modal-footer button,
+[data-dismiss-modal] {
+ cursor: pointer;
+}
+
+.dropdown-item:active{
+ background-color: initial;
+ color: initial;
+}
+
+.floating-alert {
+ position: fixed;
+ top: 20px;
+ left: 50%;
+ transform: translateX(-50%);
+ z-index: 1060;
+ width: fit-content;
+ max-width: 90%;
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
+ height: fit-content;
+}
+
+/* displayAlert toasts: same area as toast-container, auto-dismiss + progress */
+.display-alert {
+ position: relative;
+ display: flex;
+ align-items: center;
+ gap: 0.75rem;
+ min-width: 280px;
+ max-width: 420px;
+ padding: 0.875rem 1rem;
+ border-radius: 12px;
+ box-shadow: 0 10px 25px rgba(0, 0, 0, 0.15), 0 4px 10px rgba(0, 0, 0, 0.08);
+ color: #fff;
+ font-weight: 500;
+ overflow: hidden;
+ animation: display-alert-in 0.3s ease-out;
+}
+.display-alert-out {
+ animation: display-alert-out 0.28s ease-in forwards;
+}
+@keyframes display-alert-in {
+ from {
+ opacity: 0;
+ transform: translateY(-12px);
+ }
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+@keyframes display-alert-out {
+ from { opacity: 1; }
+ to { opacity: 0; }
+}
+.display-alert-icon {
+ flex-shrink: 0;
+ width: 1.5rem;
+ height: 1.5rem;
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ border-radius: 50%;
+ background: rgba(255, 255, 255, 0.25);
+ font-size: 0.875rem;
+ font-weight: bold;
+}
+.display-alert-message {
+ flex: 1;
+ min-width: 0;
+ line-height: 1.35;
+}
+.display-alert-close {
+ flex-shrink: 0;
+ width: 1.75rem;
+ height: 1.75rem;
+ padding: 0;
+ border: none;
+ border-radius: 8px;
+ background: rgba(255, 255, 255, 0.2);
+ color: inherit;
+ font-size: 1.25rem;
+ line-height: 1;
+ cursor: pointer;
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ transition: background 0.15s ease;
+}
+.display-alert-close:hover {
+ background: rgba(255, 255, 255, 0.35);
+}
+.display-alert-progress {
+ position: absolute;
+ bottom: 0;
+ left: 0;
+ height: 3px;
+ background: rgba(255, 255, 255, 0.4);
+ animation: display-alert-progress-shrink 5s linear forwards;
+ transform-origin: left;
+}
+@keyframes display-alert-progress-shrink {
+ from { transform: scaleX(1); }
+ to { transform: scaleX(0); }
+}
+.display-alert-success {
+ background: linear-gradient(135deg, #0d9488 0%, #059669 100%);
+}
+.display-alert-danger {
+ background: linear-gradient(135deg, #dc2626 0%, #b91c1c 100%);
+}
+.display-alert-warning {
+ background: linear-gradient(135deg, #d97706 0%, #b45309 100%);
+}
+.display-alert-info {
+ background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
+}
+.display-alert-fixed {
+ position: fixed;
+ top: 5rem;
+ left: 50%;
+ transform: translateX(-50%);
+ z-index: 1060;
+}
+
+.language-yaml {
+ padding: 15px;
+ border-radius: 4px;
+ background-color: unset!important;
+}
+
+.language-json {
+ padding: 15px;
+ border-radius: 4px;
+ background-color: unset!important;
+}
+
+.modal-body {
+ word-wrap: break-word;
+ overflow-wrap: break-word;
+ white-space: normal;
+}
+
+#notifications-list {
+ max-height: 70vh;
+ overflow-y: auto;
+ overflow-x: hidden;
+}
+
+#notifications-list > li {
+ max-width: 100%;
+ box-sizing: border-box;
+}
+
+#notifications-list > li > div {
+ min-width: 0;
+}
+
+.notification-text {
+ word-wrap: break-word;
+ overflow-wrap: break-word;
+ max-width: 100%;
+ white-space: normal;
+}
+
+#log-panel {
+ height: 300px;
+ overflow-y: scroll;
+ background-color: #f8f9fa;
+ padding: 15px;
+}
+.dark #log-panel {
+ background-color: #374151;
+}
+
+#yaml-content{
+ max-height: 300px;
+ overflow-y: scroll;
+}
+
+#result-content{
+ max-height: 300px;
+ overflow-y: scroll;
+}
+
+label[for="switch"] {
+ margin-left: -40px;
+ cursor: pointer;
+}
+
+#switch {
+ border: 1px solid gray;
+ cursor: pointer;
+}
+
+#folder-list{
+ max-height: 600px;
+ overflow-y: auto;
+}
+
+.folder-item{
+ cursor: pointer;
+}
+
+.file-item{
+ color: grey;
+}
+
+.parent-disabled{
+ cursor: not-allowed;
+}
+
+#folder-picker-modal-title{
+ word-break: break-all;
+}
+
+.email-chip {
+ display: inline-block;
+ background-color: #e0e0e0;
+ border-radius: 20px;
+ padding: 5px 10px;
+ margin: 2px;
+ font-size: 14px;
+ position: relative;
+}
+
+.email-chip .remove-btn {
+ margin-left: 8px;
+ cursor: pointer;
+ font-weight: bold;
+}
+
+.email-container {
+ display: flex;
+ flex-wrap: wrap;
+ padding: 5px;
+ border: 1px solid #ccc;
+ border-radius: 5px;
+}
+.dark .email-container {
+ border-color: #4b5563;
+}
+
+
+#modal-yaml-content{
+ max-height: 450px;
+}
+
+@keyframes spin {
+ to {
+ transform: rotate(360deg);
+ }
+}
+
+@media (max-width: 992px) {
+ .detail-panel,
+ .associations-panel {
+ max-width: 100%;
+ padding-right: 0;
+ padding-left: 0;
+ }
+}
+
+@media (max-width: 768px) {
+ .detail-container {
+ flex-direction: column;
+ }
}
\ No newline at end of file
diff --git a/cli/medperf/web_ui/static/css/prism-dark.css b/cli/medperf/web_ui/static/css/prism-dark.css
new file mode 100644
index 000000000..80b623b55
--- /dev/null
+++ b/cli/medperf/web_ui/static/css/prism-dark.css
@@ -0,0 +1,161 @@
+/**
+Taken from: https://cdnjs.cloudflare.com/ajax/libs/prism-themes/1.5.0/prism-a11y-dark.min.css
+added: token.operator background color
+*/
+code[class*=language-],
+pre[class*=language-] {
+ color: #f8f8f2;
+ background: 0 0;
+ font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
+ text-align: left;
+ white-space: pre;
+ word-spacing: normal;
+ word-break: normal;
+ word-wrap: normal;
+ line-height: 1.5;
+ -moz-tab-size: 4;
+ -o-tab-size: 4;
+ tab-size: 4;
+ -webkit-hyphens: none;
+ -moz-hyphens: none;
+ -ms-hyphens: none;
+ hyphens: none
+}
+
+pre[class*=language-] {
+ padding: 1em;
+ margin: .5em 0;
+ overflow: auto;
+ border-radius: .3em
+}
+
+:not(pre)>code[class*=language-],
+pre[class*=language-] {
+ background: #2b2b2b
+}
+
+:not(pre)>code[class*=language-] {
+ padding: .1em;
+ border-radius: .3em;
+ white-space: normal
+}
+
+.token.cdata,
+.token.comment,
+.token.doctype,
+.token.prolog {
+ color: #d4d0ab
+}
+
+.token.punctuation {
+ color: #fefefe
+}
+
+.token.constant,
+.token.deleted,
+.token.property,
+.token.symbol,
+.token.tag {
+ color: #ffa07a
+}
+
+.token.boolean,
+.token.number {
+ color: #00e0e0
+}
+
+.token.attr-name,
+.token.builtin,
+.token.char,
+.token.inserted,
+.token.selector,
+.token.string {
+ color: #abe338
+}
+
+.language-css .token.string,
+.style .token.string,
+.token.entity,
+.token.operator,
+.token.url,
+.token.variable {
+ color: #00e0e0
+}
+
+.token.operator {
+ background-color: var(--bs-gray-700);
+}
+
+
+.token.atrule,
+.token.attr-value,
+.token.function {
+ color: gold
+}
+
+.token.keyword {
+ color: #00e0e0
+}
+
+.token.important,
+.token.regex {
+ color: gold
+}
+
+.token.bold,
+.token.important {
+ font-weight: 700
+}
+
+.token.italic {
+ font-style: italic
+}
+
+.token.entity {
+ cursor: help
+}
+
+@media screen and (-ms-high-contrast:active) {
+
+ code[class*=language-],
+ pre[class*=language-] {
+ color: windowText;
+ background: window
+ }
+
+ :not(pre)>code[class*=language-],
+ pre[class*=language-] {
+ background: window
+ }
+
+ .token.important {
+ background: highlight;
+ color: window;
+ font-weight: 400
+ }
+
+ .token.atrule,
+ .token.attr-value,
+ .token.function,
+ .token.keyword,
+ .token.operator,
+ .token.selector {
+ font-weight: 700
+ }
+
+ .token.attr-value,
+ .token.comment,
+ .token.doctype,
+ .token.function,
+ .token.keyword,
+ .token.operator,
+ .token.property,
+ .token.string {
+ color: highlight
+ }
+
+ .token.attr-value,
+ .token.url {
+ font-weight: 400
+ }
+}
\ No newline at end of file
diff --git a/cli/medperf/web_ui/static/js/aggregators/aggregator_detail.js b/cli/medperf/web_ui/static/js/aggregators/aggregator_detail.js
new file mode 100644
index 000000000..28c21c04c
--- /dev/null
+++ b/cli/medperf/web_ui/static/js/aggregators/aggregator_detail.js
@@ -0,0 +1,126 @@
+const RUNNING_TASKS_POLL_MS = 2000;
+const AGGREGATOR_CONTAINER_TASK_NAME = "start_aggregator";
+var pollingIntervalId = null;
+
+function updateRunningBanner(tasks) {
+ var banner = document.getElementById("aggregator-running-banner");
+ var runCard = document.getElementById("aggregator-run-card");
+ if (!banner || !runCard) return;
+ var running = Array.isArray(tasks) && tasks.indexOf(AGGREGATOR_CONTAINER_TASK_NAME) !== -1;
+ if (running) {
+ banner.classList.remove("hidden");
+ runCard.classList.add("opacity-80");
+ } else {
+ banner.classList.add("hidden");
+ runCard.classList.remove("opacity-80");
+ var submit = runCard.querySelector('form[action*="/aggregators/run"] button[type="submit"]');
+ if (submit && window.taskName !== RUN_AGGREGATOR_TASK_ID) submit.disabled = false;
+ if (pollingIntervalId && window.taskName !== RUN_AGGREGATOR_TASK_ID) { clearInterval(pollingIntervalId); pollingIntervalId = null; }
+ }
+}
+
+function pollRunningTasks() {
+ fetch("/api/running_tasks", { method: "GET" })
+ .then(function (r) { return r.json(); })
+ .then(function (data) {
+ if (data && Array.isArray(data.tasks)) updateRunningBanner(data.tasks);
+ })
+ .catch(function () {});
+}
+
+function startPollingRunningTasks() {
+ pollRunningTasks();
+ if (!pollingIntervalId) pollingIntervalId = setInterval(pollRunningTasks, RUNNING_TASKS_POLL_MS);
+}
+
+/** Task ID for event stream (matches backend active_tasks / task_id). */
+const RUN_AGGREGATOR_TASK_ID = "start_aggregator";
+/** Task name in running_containers (matches cube.run task=). Used for polling and stop. */
+const GET_SERVER_CERT_TASK_NAME = "aggregator_get_server_cert";
+
+function getAggregatorId(form) {
+ var input = form ? form.querySelector('input[name="aggregator_id"]') : null;
+ return input ? input.value : null;
+}
+
+function onGetServerCertSuccess(response) {
+ if (response.status === "success") {
+ showReloadModal({ title: "Server Certificate Retrieved Successfully", seconds: 3 });
+ } else {
+ showErrorModal("Failed to Get Server Certificate", response);
+ }
+}
+
+function onRunAggregatorSuccess(response) {
+ if (response && response.status === "started") {
+ displayAlert("success", "Aggregator worker started successfully.");
+ startPollingRunningTasks();
+ } else showErrorModal("Something went wrong while running the aggregator", response);
+}
+
+function submitActionFormWithForm(form) {
+ var formData = new FormData(form);
+ var panelTitle = form.getAttribute("data-panel-title") || "Action";
+ var isRunForm = (form.getAttribute("action") || "").indexOf("/aggregators/run") !== -1;
+
+ disableElements(".detail-container form button, .detail-container form input, .detail-container form select");
+ var submitBtn = form.querySelector('button[type="submit"]');
+ if (submitBtn) addSpinner(submitBtn);
+ showPanel(panelTitle + "...");
+
+ var successCallback = isRunForm ? onRunAggregatorSuccess : onGetServerCertSuccess;
+ window.onPromptComplete = successCallback;
+ window.taskName = isRunForm ? RUN_AGGREGATOR_TASK_ID : GET_SERVER_CERT_TASK_NAME;
+
+ ajaxRequest(
+ form.action,
+ "POST",
+ formData,
+ successCallback,
+ "Error: " + panelTitle
+ );
+ streamEvents(logPanel, stagesList, currentStageElement);
+ if (isRunForm) startPollingRunningTasks();
+}
+
+function submitActionForm(e) {
+ e.preventDefault();
+ var form = e.target;
+ var msg = form.getAttribute("data-confirm-message") || "continue?";
+ showConfirmModal(form, submitActionFormWithForm, msg);
+}
+
+function stopAggregator() {
+ var btn = document.getElementById("stop-aggregator-btn");
+ btn.disabled = true;
+ var formData = new FormData();
+ formData.append("task_name", AGGREGATOR_CONTAINER_TASK_NAME);
+ fetch("/api/stop_task", { method: "POST", body: formData })
+ .then(function (r) {
+ if (r.ok) {
+ updateRunningBanner([]);
+ displayAlert("success", "Aggregator stopped.");
+ }
+ btn.disabled = false;
+ })
+ .catch(function () { btn.disabled = false; });
+}
+
+function init() {
+ var actionForms = document.querySelectorAll('#start-aggregator-form, #get-server-cert-form');
+ actionForms.forEach(function (form) {
+ form.addEventListener("submit", submitActionForm);
+ });
+
+ var stopBtn = document.getElementById("stop-aggregator-btn");
+ if (stopBtn) stopBtn.addEventListener("click", function (e) {
+ showConfirmModal(e.currentTarget, function () { stopAggregator(); }, "stop the running aggregator?");
+ });
+
+}
+
+if (document.readyState === "loading") {
+ document.addEventListener("DOMContentLoaded", init);
+} else {
+ init();
+}
\ No newline at end of file
diff --git a/cli/medperf/web_ui/static/js/aggregators/aggregator_register.js b/cli/medperf/web_ui/static/js/aggregators/aggregator_register.js
new file mode 100644
index 000000000..f05fe6a37
--- /dev/null
+++ b/cli/medperf/web_ui/static/js/aggregators/aggregator_register.js
@@ -0,0 +1,27 @@
+function checkAggregatorFormValidity() {
+ var nameEl = document.getElementById("name");
+ var addressEl = document.getElementById("address");
+ var portEl = document.getElementById("port");
+ var cubeEl = document.getElementById("aggregation-mlcube");
+ var nameValue = nameEl ? nameEl.value.trim() : "";
+ var addressValue = addressEl ? addressEl.value.trim() : "";
+ var portValue = portEl && portEl.value ? parseInt(portEl.value, 10) : 0;
+ var cubeValue = cubeEl && cubeEl.value ? parseInt(cubeEl.value, 10) : 0;
+ var isValid = nameValue.length > 0 && addressValue.length > 0 && portValue > 0 && portValue <= 65535 && cubeValue > 0;
+ var btn = document.getElementById("register-aggregator-btn");
+ if (btn) btn.disabled = !isValid;
+}
+
+function init() {
+ var form = document.getElementById("aggregator-register-form");
+ if (form) {
+ form.addEventListener("submit", submitActionForm);
+ form.querySelectorAll("input, select").forEach(function (el) {
+ el.addEventListener("keyup", checkAggregatorFormValidity);
+ el.addEventListener("change", checkAggregatorFormValidity);
+ });
+ }
+ checkAggregatorFormValidity();
+}
+if (document.readyState === "loading") document.addEventListener("DOMContentLoaded", init);
+else init();
diff --git a/cli/medperf/web_ui/static/js/assets/asset_register.js b/cli/medperf/web_ui/static/js/assets/asset_register.js
index 5e3f0c079..bd251275a 100644
--- a/cli/medperf/web_ui/static/js/assets/asset_register.js
+++ b/cli/medperf/web_ui/static/js/assets/asset_register.js
@@ -1,70 +1,49 @@
-function onAssetRegisterSuccess(response){
- markAllStagesAsComplete();
- if(response.status === "success"){
- showReloadModal({
- title: "Asset Registered Successfully",
- seconds: 3,
- url: "/assets/ui/display/"+response.asset_id
- });
- }
- else{
- showErrorModal("Failed to Register Asset", response);
- }
-}
-
-async function registerAsset(registerButton){
- addSpinner(registerButton);
-
- const formData = new FormData($("#asset-register-form")[0]);
-
- disableElements("#asset-register-form input, #asset-register-form button");
-
- ajaxRequest(
- "/assets/register",
- "POST",
- formData,
- onAssetRegisterSuccess,
- "Error registering asset:"
- )
-
- showPanel(`Registering Asset...`);
- window.runningTaskId = await getTaskId();
- streamEvents(logPanel, stagesList, currentStageElement);
-}
+var REDIRECT_BASE = "/assets/ui/display/";
function checkAssetFormValidity() {
- const assetURL = $("#asset-url").val().trim();
- const isRemote = $("input[name='asset_is_remote']:checked").val();
- const assetPath = $("#asset-path").val().trim();
-
- const isValid = Boolean(
- $("#name").val().trim() &&
- (isRemote === "true" ? assetURL.length > 0 : isRemote === "false" && assetPath.length > 0)
- );
- $("#register-asset-btn").prop("disabled", !isValid);
+ var nameVal = document.getElementById("name") ? document.getElementById("name").value.trim() : "";
+ var isRemote = document.querySelector("input[name='asset_is_remote']:checked");
+ var remoteVal = isRemote ? isRemote.value : "false";
+ var assetURL = document.getElementById("asset-url") ? document.getElementById("asset-url").value.trim() : "";
+ var assetPath = document.getElementById("asset-path") ? document.getElementById("asset-path").value.trim() : "";
+ var isValid = !!nameVal && (remoteVal === "true" ? assetURL.length > 0 : remoteVal === "false" && assetPath.length > 0);
+ var btn = document.getElementById("register-asset-btn");
+ if (btn) btn.disabled = !isValid;
}
-$(document).ready(() => {
- $("#register-asset-btn").on("click", (e) => {
- showConfirmModal(e.currentTarget, registerAsset, "register this asset?");
+function initAssetRegister() {
+ var form = document.getElementById("asset-register-form");
+ if (form) {
+ form.addEventListener("submit", submitActionForm);
+ form.querySelectorAll("input").forEach(function (el) {
+ el.addEventListener("keyup", checkAssetFormValidity);
+ el.addEventListener("change", checkAssetFormValidity);
+ });
+ }
+ var browseBtn = document.getElementById("browse-asset-btn");
+ if (browseBtn) browseBtn.addEventListener("click", function () { browseWithFiles = true; browseFolderHandler("asset-path"); });
+ document.querySelectorAll("input[name='asset_is_remote']").forEach(function (radio) {
+ radio.addEventListener("change", function () {
+ var urlContainer = document.getElementById("asset-url-container");
+ var pathContainer = document.getElementById("asset-path-container");
+ var assetUrlInput = document.getElementById("asset-url");
+ var assetPathInput = document.getElementById("asset-path");
+ if (this.value === "false") {
+ if (urlContainer) urlContainer.classList.add("hidden");
+ if (pathContainer) pathContainer.classList.remove("hidden");
+ if (assetUrlInput) assetUrlInput.value = "";
+ } else {
+ if (pathContainer) pathContainer.classList.add("hidden");
+ if (urlContainer) urlContainer.classList.remove("hidden");
+ if (assetPathInput) assetPathInput.value = "";
+ }
+ });
});
+ checkAssetFormValidity();
+}
- $("#asset-register-form input").on("keyup change", checkAssetFormValidity);
-
- $("#browse-asset-btn").on("click", () => {
- browseWithFiles = true;
- browseFolderHandler("asset-path");
- });
- $("input[name='asset_is_remote']").on("change", () => {
- if($("#local").is(":checked")){
- $("#asset-path-container").show();
- $("#asset-url-container").hide();
- $("#asset-url").val("");
- }
- else{
- $("#asset-url-container").show();
- $("#asset-path-container").hide();
- $("#asset-path").val("");
- }
- });
-});
+if (document.readyState === "loading") {
+ document.addEventListener("DOMContentLoaded", initAssetRegister);
+} else {
+ initAssetRegister();
+}
diff --git a/cli/medperf/web_ui/static/js/benchmarks/benchmark_detail.js b/cli/medperf/web_ui/static/js/benchmarks/benchmark_detail.js
index 67946df37..d4961cd0e 100644
--- a/cli/medperf/web_ui/static/js/benchmarks/benchmark_detail.js
+++ b/cli/medperf/web_ui/static/js/benchmarks/benchmark_detail.js
@@ -1,247 +1,154 @@
-
-function showConfirmationPrompt(approveRejectBtn){
- const entityType = approveRejectBtn.getAttribute("data-entity-type");
- const actionName = approveRejectBtn.getAttribute("data-action-name");
- const benchmarkId = approveRejectBtn.getAttribute("data-benchmark-id");
- const entityId = approveRejectBtn.getAttribute(`data-${entityType}-id`);
-
- let message = actionName + " this association?<br>This action cannot be undone.";
- const callback = () => {
- approveRejectAssociation(actionName, benchmarkId, entityId, entityType, approveRejectBtn);
- };
- showConfirmModal(approveRejectBtn, callback, message);
-}
-
-function onApproveRejectAssociationSuccess(response, actionName){
- let title;
- if(response.status === "success"){
- if(actionName === "approve")
- title = "Association Approved Successfully";
- else
- title = "Association Rejected Successfully";
- showReloadModal({
- title: title,
- seconds: 3
- });
- }
- else{
- if(actionName === "approve")
- title = "Failed to Approve Association";
- else
- title = "Failed to Reject Association";
- showErrorModal(title, response)
- }
-}
-
-async function approveRejectAssociation(actionName, benchmarkId, entityId, entityType, approveRejectBtn){
- addSpinner(approveRejectBtn);
- disableElements(".card button");
-
- const formData = new FormData();
- formData.append("benchmark_id", benchmarkId);
- formData.append(`${entityType}_id`, entityId);
-
- ajaxRequest(
- `/benchmarks/${actionName}`,
- "POST",
- formData,
- function(response) {
- onApproveRejectAssociationSuccess(response, actionName);
- },
- "Error approving/rejecting associtation"
- );
-
- window.runningTaskId = await getTaskId();
-}
+var REDIRECT_BASE = "/benchmarks/ui/display/";
function isValidEmail(email) {
return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(email);
}
-function createEmailChip(email, input_element) {
- const $chip = $("
${results}" + (results || "").replace(/