From 0b9f655361832f85fa8f621f61b3e8324569c108 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 00:52:41 +0100 Subject: [PATCH 01/72] some webui messages typos --- .../web_ui/static/js/containers/container_register.js | 4 ++-- .../web_ui/templates/container/register_container.html | 4 ++-- .../web_ui/tests/e2e/test_medperf_tutorial_workflow.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cli/medperf/web_ui/static/js/containers/container_register.js b/cli/medperf/web_ui/static/js/containers/container_register.js index f9914d5e2..eb7b51b3e 100644 --- a/cli/medperf/web_ui/static/js/containers/container_register.js +++ b/cli/medperf/web_ui/static/js/containers/container_register.js @@ -2,13 +2,13 @@ function onContainerRegisterSuccess(response){ markAllStagesAsComplete(); if(response.status === "success"){ showReloadModal({ - title: "Model Registered Successfully", + title: "Container Registered Successfully", seconds: 3, url: "/containers/ui/display/"+response.container_id }); } else{ - showErrorModal("Failed to Register Model", response); + showErrorModal("Failed to Register Container", response); } } diff --git a/cli/medperf/web_ui/templates/container/register_container.html b/cli/medperf/web_ui/templates/container/register_container.html index aecb16a60..e682e0a19 100644 --- a/cli/medperf/web_ui/templates/container/register_container.html +++ b/cli/medperf/web_ui/templates/container/register_container.html @@ -121,11 +121,11 @@

Register a New Container

- +
- +
diff --git a/cli/medperf/web_ui/tests/e2e/test_medperf_tutorial_workflow.py b/cli/medperf/web_ui/tests/e2e/test_medperf_tutorial_workflow.py index c23d9fb05..d98892b3e 100644 --- a/cli/medperf/web_ui/tests/e2e/test_medperf_tutorial_workflow.py +++ b/cli/medperf/web_ui/tests/e2e/test_medperf_tutorial_workflow.py @@ -129,7 +129,7 @@ def test_benchmark_register_data_prep_container(driver): while not page_modal.is_displayed(): time.sleep(0.2) - assert page.get_text(page.PAGE_MODAL_TITLE) == "Model Registered Successfully" + assert page.get_text(page.PAGE_MODAL_TITLE) == "Container Registered Successfully" page.wait_for_staleness_element(page_modal) page.wait_for_url_change(old_url) @@ -170,7 +170,7 @@ def test_benchmark_register_reference_model_container(driver): while not page_modal.is_displayed(): time.sleep(0.2) - assert page.get_text(page.PAGE_MODAL_TITLE) == "Model Registered Successfully" + assert page.get_text(page.PAGE_MODAL_TITLE) == "Container Registered Successfully" page.wait_for_staleness_element(page_modal) page.wait_for_url_change(old_url) @@ -211,7 +211,7 @@ def test_benchmark_register_metrics_container(driver): while not page_modal.is_displayed(): time.sleep(0.2) - assert page.get_text(page.PAGE_MODAL_TITLE) == "Model Registered Successfully" + assert page.get_text(page.PAGE_MODAL_TITLE) == "Container Registered Successfully" page.wait_for_staleness_element(page_modal) page.wait_for_url_change(old_url) @@ -582,7 +582,7 @@ def test_container_registration(driver): while not page_modal.is_displayed(): time.sleep(0.2) - assert page.get_text(page.PAGE_MODAL_TITLE) == "Model Registered Successfully" + assert page.get_text(page.PAGE_MODAL_TITLE) == "Container Registered Successfully" page.wait_for_staleness_element(page_modal) page.wait_for_url_change(old_url) @@ -676,7 +676,7 @@ def test_encrypted_container_registration(driver): while not page_modal.is_displayed(): time.sleep(0.2) - assert page.get_text(page.PAGE_MODAL_TITLE) == "Model Registered 
Successfully" + assert page.get_text(page.PAGE_MODAL_TITLE) == "Container Registered Successfully" page.wait_for_staleness_element(page_modal) page.wait_for_url_change(old_url) From 610bb25e99781ebfedac957088016574af79e41c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 02:04:54 +0100 Subject: [PATCH 02/72] add association signing --- .../commands/association/association.py | 14 +++++++ cli/medperf/commands/association/sign.py | 30 +++++++++++++ cli/medperf/commands/association/utils.py | 30 +++++++++++++ cli/medperf/encryption.py | 42 ++++++++++++++++++- cli/medperf/entities/benchmark.py | 18 +++++--- .../0002_benchmarkdataset_signature.py | 18 ++++++++ server/benchmarkdataset/models.py | 1 + server/benchmarkdataset/serializers.py | 1 + server/benchmarkdataset/views.py | 2 + .../0004_benchmarkmodel_signature.py | 18 ++++++++ server/benchmarkmodel/models.py | 1 + server/benchmarkmodel/serializers.py | 1 + server/benchmarkmodel/views.py | 2 + 13 files changed, 172 insertions(+), 6 deletions(-) create mode 100644 cli/medperf/commands/association/sign.py create mode 100644 server/benchmarkdataset/migrations/0002_benchmarkdataset_signature.py create mode 100644 server/benchmarkmodel/migrations/0004_benchmarkmodel_signature.py diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py index a6f52346a..9dc88c4f3 100644 --- a/cli/medperf/commands/association/association.py +++ b/cli/medperf/commands/association/association.py @@ -5,6 +5,7 @@ from medperf.commands.association.list import ListAssociations from medperf.commands.association.approval import Approval from medperf.commands.association.priority import AssociationPriority +from medperf.commands.association.sign import SignAssociations from medperf.enums import Status app = typer.Typer() @@ -122,3 +123,16 @@ def set_priority( """ AssociationPriority.run(benchmark_uid, model_uid, priority) config.ui.print("✅ Done!") + + +@app.command("sign") 
+@clean_except +def sign( + benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), +): + """Sign all associations related to a specific benchmark. + + Args: + benchmark_uid (int): Benchmark UID. + """ + SignAssociations.run(benchmark_uid) diff --git a/cli/medperf/commands/association/sign.py b/cli/medperf/commands/association/sign.py new file mode 100644 index 000000000..2256b91a2 --- /dev/null +++ b/cli/medperf/commands/association/sign.py @@ -0,0 +1,30 @@ +from medperf.entities.benchmark import Benchmark +from medperf.commands.association.utils import ( + sign_dataset_association, + sign_model_association, +) +from medperf.account_management.account_management import get_medperf_user_data +from medperf.exceptions import MedperfException + + +class SignAssociations: + @staticmethod + def run(benchmark_uid: int): + benchmark = Benchmark.get(benchmark_uid) + current_user_id = get_medperf_user_data()["id"] + if benchmark.owner != current_user_id: + raise MedperfException( + "You are not the owner of this benchmark and cannot sign its associations." 
+ ) + data_assocs = Benchmark.get_datasets_associations( + benchmark_uid=benchmark_uid, approval_status="approved" + ) + model_assocs = Benchmark.get_models_associations( + benchmark_uid=benchmark_uid, approval_status="approved" + ) + + for assoc in data_assocs: + sign_dataset_association(assoc) + + for assoc in model_assocs: + sign_model_association(assoc) diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py index 80a5ffbea..9588c82e1 100644 --- a/cli/medperf/commands/association/utils.py +++ b/cli/medperf/commands/association/utils.py @@ -1,6 +1,10 @@ from medperf.exceptions import InvalidArgumentError, MedperfException from medperf import config from pydantic.datetime_parse import parse_datetime +from medperf.entities.dataset import Dataset +from medperf.entities.model import Model +from medperf.encryption import Signing +from medperf.commands.certificate.utils import load_user_private_key def validate_args(benchmark, training_exp, dataset, model, aggregator, approval_status): @@ -151,3 +155,29 @@ def _post_process_associtations( ] return assocs + + +def sign_dataset_association(assoc): + dataset_id = assoc["dataset"] + benchmak_id = assoc["benchmark"] + dataset = Dataset.get(dataset_id) + dataset_hash = dataset.generated_uid + user_private_key = load_user_private_key() + signature = Signing().sign_prehashed(user_private_key, dataset_hash) + body = {"signature": signature} + config.comms.update_benchmark_dataset_association(benchmak_id, dataset_id, body) + + +def sign_model_association(assoc): + model_id = assoc["model"] + benchmark_id = assoc["benchmark"] + model = Model.get(model_id) + if model.is_container(): + model_hash = model.container.image_hash + model_hash = model_hash.replace("sha256:", "") + else: + model_hash = model.asset.asset_hash + user_private_key = load_user_private_key() + signature = Signing().sign_prehashed(user_private_key, model_hash) + body = {"signature": signature} + 
config.comms.update_benchmark_model_association(benchmark_id, model_id, body) diff --git a/cli/medperf/encryption.py b/cli/medperf/encryption.py index 899399417..d7ad52574 100644 --- a/cli/medperf/encryption.py +++ b/cli/medperf/encryption.py @@ -1,3 +1,4 @@ +import binascii import os from medperf.exceptions import ( DecryptionError, @@ -6,9 +7,10 @@ MedperfException, ) from medperf.utils import run_command -from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric import padding, utils from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization +from cryptography.exceptions import InvalidSignature from cryptography import x509 import logging @@ -111,3 +113,41 @@ def decrypt(self, private_key_bytes: bytes, encrypted_data_bytes: bytes) -> byte return data_bytes except Exception as e: raise DecryptionError(f"Data decryption failed: {str(e)}") + + +class Signing: + def __init__(self): + self.padding = padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH, + ) + + def sign_prehashed(self, private_key_bytes: bytes, data_hash_hex: str) -> bytes: + logging.debug("Performing Asymmetric Signing") + try: + private_key = serialization.load_pem_private_key( + data=private_key_bytes, password=None + ) + data_hash = binascii.unhexlify(data_hash_hex) + signature = private_key.sign( + data_hash, self.padding, utils.Prehashed(hashes.SHA256()) + ) + return signature + except Exception as e: + raise EncryptionError(f"Data signing failed: {str(e)}") + + def verify_prehashed( + self, public_key_bytes: bytes, data_hash_hex: str, signature: bytes + ) -> bool: + logging.debug("Performing Asymmetric Signature Verification") + try: + public_key_obj = serialization.load_pem_public_key(public_key_bytes) + data_hash = binascii.unhexlify(data_hash_hex) + public_key_obj.verify( + signature, data_hash, self.padding, utils.Prehashed(hashes.SHA256()) + ) + return True 
+ except InvalidSignature: + return False + except Exception as e: + raise EncryptionError(f"Signature verification failed: {str(e)}") diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index da38f66fa..e2e734e37 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -124,11 +124,15 @@ def get_datasets_with_users(cls, benchmark_uid: int) -> List[dict]: return uids_with_users @classmethod - def get_models_associations(cls, benchmark_uid: int) -> List[dict]: + def get_models_associations( + cls, benchmark_uid: int, approval_status: str = None + ) -> List[dict]: """Retrieves the list of model associations to the benchmark Args: benchmark_uid (int): UID of the benchmark. + approval_status (str, optional): Filter associations by approval status. + Defaults to None, which retrieves all associations. Returns: List[dict]: List of associations @@ -140,7 +144,7 @@ def get_models_associations(cls, benchmark_uid: int) -> List[dict]: associations = get_user_associations( experiment_type=experiment_type, component_type=component_type, - approval_status=None, + approval_status=approval_status, ) associations = [a for a in associations if a["benchmark"] == benchmark_uid] @@ -148,11 +152,15 @@ def get_models_associations(cls, benchmark_uid: int) -> List[dict]: return associations @classmethod - def get_datasets_associations(cls, benchmark_uid: int) -> List[dict]: - """Retrieves the list of models associated to the benchmark + def get_datasets_associations( + cls, benchmark_uid: int, approval_status: str = None + ) -> List[dict]: + """Retrieves the list of datasets associated to the benchmark Args: benchmark_uid (int): UID of the benchmark. + approval_status (str, optional): Filter associations by approval status. + Defaults to None, which retrieves all associations. 
Returns: List[dict]: List of associations @@ -164,7 +172,7 @@ def get_datasets_associations(cls, benchmark_uid: int) -> List[dict]: associations = get_user_associations( experiment_type=experiment_type, component_type=component_type, - approval_status=None, # TODO + approval_status=approval_status, ) associations = [a for a in associations if a["benchmark"] == benchmark_uid] diff --git a/server/benchmarkdataset/migrations/0002_benchmarkdataset_signature.py b/server/benchmarkdataset/migrations/0002_benchmarkdataset_signature.py new file mode 100644 index 000000000..979487ba3 --- /dev/null +++ b/server/benchmarkdataset/migrations/0002_benchmarkdataset_signature.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.23 on 2026-03-09 00:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('benchmarkdataset', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='benchmarkdataset', + name='signature', + field=models.CharField(blank=True, max_length=1000, null=True), + ), + ] diff --git a/server/benchmarkdataset/models.py b/server/benchmarkdataset/models.py index a5209085e..4318d629d 100644 --- a/server/benchmarkdataset/models.py +++ b/server/benchmarkdataset/models.py @@ -20,6 +20,7 @@ class BenchmarkDataset(models.Model): approved_at = models.DateTimeField(null=True, blank=True) created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) + signature = models.CharField(max_length=1000, null=True, blank=True) class Meta: ordering = ["modified_at"] diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py index 6ba22a4bb..fca3df923 100644 --- a/server/benchmarkdataset/serializers.py +++ b/server/benchmarkdataset/serializers.py @@ -81,6 +81,7 @@ class Meta: "approved_at", "created_at", "modified_at", + "signature", ] def validate(self, data): diff --git a/server/benchmarkdataset/views.py b/server/benchmarkdataset/views.py 
index 3ac96ce22..9588c88ce 100644 --- a/server/benchmarkdataset/views.py +++ b/server/benchmarkdataset/views.py @@ -61,6 +61,8 @@ def get_permissions(self): self.permission_classes = [IsAdmin | IsBenchmarkOwner | IsDatasetOwner] if self.request.method == "DELETE": self.permission_classes = [IsAdmin] + elif self.request.method == "PUT" and "signature" in self.request.data: + self.permission_classes = [IsAdmin | IsBenchmarkOwner] return super(self.__class__, self).get_permissions() def get_object(self, dataset_id, benchmark_id): diff --git a/server/benchmarkmodel/migrations/0004_benchmarkmodel_signature.py b/server/benchmarkmodel/migrations/0004_benchmarkmodel_signature.py new file mode 100644 index 000000000..b95fd6e5f --- /dev/null +++ b/server/benchmarkmodel/migrations/0004_benchmarkmodel_signature.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.23 on 2026-03-09 00:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('benchmarkmodel', '0003_remove_benchmarkmodel_model_mlcube_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='benchmarkmodel', + name='signature', + field=models.CharField(blank=True, max_length=1000, null=True), + ), + ] diff --git a/server/benchmarkmodel/models.py b/server/benchmarkmodel/models.py index fd76eadd8..4bfafd20f 100644 --- a/server/benchmarkmodel/models.py +++ b/server/benchmarkmodel/models.py @@ -21,6 +21,7 @@ class BenchmarkModel(models.Model): created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) priority = models.IntegerField(default=0) + signature = models.CharField(max_length=1000, null=True, blank=True) class Meta: ordering = ["-priority"] diff --git a/server/benchmarkmodel/serializers.py b/server/benchmarkmodel/serializers.py index ba70481ee..b16608325 100644 --- a/server/benchmarkmodel/serializers.py +++ b/server/benchmarkmodel/serializers.py @@ -75,6 +75,7 @@ class Meta: "created_at", "modified_at", 
"priority", + "signature", ] def validate(self, data): diff --git a/server/benchmarkmodel/views.py b/server/benchmarkmodel/views.py index fa3bec6e4..7f1303489 100644 --- a/server/benchmarkmodel/views.py +++ b/server/benchmarkmodel/views.py @@ -61,6 +61,8 @@ def get_permissions(self): self.permission_classes = [IsAdmin | IsBenchmarkOwner | IsModelOwner] if self.request.method == "PUT" and "priority" in self.request.data: self.permission_classes = [IsAdmin | IsBenchmarkOwner] + elif self.request.method == "PUT" and "signature" in self.request.data: + self.permission_classes = [IsAdmin | IsBenchmarkOwner] elif self.request.method == "DELETE": self.permission_classes = [IsAdmin] return super(self.__class__, self).get_permissions() From 1bcd289279c6f53ca501c98aa7f11f836f5e75be Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 20:48:30 +0100 Subject: [PATCH 03/72] fix circular imports for signing assocs --- cli/medperf/commands/association/sign.py | 35 ++++++++++++++++++++--- cli/medperf/commands/association/utils.py | 30 ------------------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/cli/medperf/commands/association/sign.py b/cli/medperf/commands/association/sign.py index 2256b91a2..0079871a8 100644 --- a/cli/medperf/commands/association/sign.py +++ b/cli/medperf/commands/association/sign.py @@ -1,10 +1,37 @@ from medperf.entities.benchmark import Benchmark -from medperf.commands.association.utils import ( - sign_dataset_association, - sign_model_association, -) +from medperf.entities.dataset import Dataset +from medperf.entities.model import Model +from medperf.encryption import Signing +from medperf.commands.certificate.utils import load_user_private_key from medperf.account_management.account_management import get_medperf_user_data from medperf.exceptions import MedperfException +from medperf import config + + +def sign_dataset_association(assoc): + dataset_id = assoc["dataset"] + benchmak_id = assoc["benchmark"] + dataset = 
Dataset.get(dataset_id) + dataset_hash = dataset.generated_uid + user_private_key = load_user_private_key() + signature = Signing().sign_prehashed(user_private_key, dataset_hash) + body = {"signature": signature} + config.comms.update_benchmark_dataset_association(benchmak_id, dataset_id, body) + + +def sign_model_association(assoc): + model_id = assoc["model"] + benchmark_id = assoc["benchmark"] + model = Model.get(model_id) + if model.is_container(): + model_hash = model.container.image_hash + model_hash = model_hash.replace("sha256:", "") + else: + model_hash = model.asset.asset_hash + user_private_key = load_user_private_key() + signature = Signing().sign_prehashed(user_private_key, model_hash) + body = {"signature": signature} + config.comms.update_benchmark_model_association(benchmark_id, model_id, body) class SignAssociations: diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py index 9588c82e1..80a5ffbea 100644 --- a/cli/medperf/commands/association/utils.py +++ b/cli/medperf/commands/association/utils.py @@ -1,10 +1,6 @@ from medperf.exceptions import InvalidArgumentError, MedperfException from medperf import config from pydantic.datetime_parse import parse_datetime -from medperf.entities.dataset import Dataset -from medperf.entities.model import Model -from medperf.encryption import Signing -from medperf.commands.certificate.utils import load_user_private_key def validate_args(benchmark, training_exp, dataset, model, aggregator, approval_status): @@ -155,29 +151,3 @@ def _post_process_associtations( ] return assocs - - -def sign_dataset_association(assoc): - dataset_id = assoc["dataset"] - benchmak_id = assoc["benchmark"] - dataset = Dataset.get(dataset_id) - dataset_hash = dataset.generated_uid - user_private_key = load_user_private_key() - signature = Signing().sign_prehashed(user_private_key, dataset_hash) - body = {"signature": signature} - config.comms.update_benchmark_dataset_association(benchmak_id, 
dataset_id, body) - - -def sign_model_association(assoc): - model_id = assoc["model"] - benchmark_id = assoc["benchmark"] - model = Model.get(model_id) - if model.is_container(): - model_hash = model.container.image_hash - model_hash = model_hash.replace("sha256:", "") - else: - model_hash = model.asset.asset_hash - user_private_key = load_user_private_key() - signature = Signing().sign_prehashed(user_private_key, model_hash) - body = {"signature": signature} - config.comms.update_benchmark_model_association(benchmark_id, model_id, body) From 7ed505ba3b498885d4ccbafdec8cc37f1928c920 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 20:51:29 +0100 Subject: [PATCH 04/72] better apis for cc --- .../asset_management/asset_management.py | 37 ++++++------------- cli/medperf/commands/cc/cc.py | 6 +-- .../commands/cc/dataset_configure_for_cc.py | 13 +++---- .../commands/cc/model_configure_for_cc.py | 12 +++--- cli/medperf/commands/cc/setup_cc_operator.py | 12 +++--- cli/medperf/commands/execution/create.py | 4 +- .../commands/execution/execution_flow.py | 2 +- cli/medperf/entities/dataset.py | 16 ++------ cli/medperf/entities/model.py | 20 +++------- cli/medperf/entities/user.py | 14 +------ 10 files changed, 46 insertions(+), 90 deletions(-) diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py index 84ec0d26d..5d2e2e00b 100644 --- a/cli/medperf/asset_management/asset_management.py +++ b/cli/medperf/asset_management/asset_management.py @@ -20,16 +20,11 @@ def generate_encryption_key(encryption_key_file: str): def setup_dataset_for_cc(dataset: Dataset): + if not dataset.is_cc_configured(): + return cc_config = dataset.get_cc_config() cc_policy = dataset.get_cc_policy() - if not cc_config: - raise ValueError( - f"Dataset {dataset.id} does not have a configuration for confidential computing." 
- ) - if cc_policy is None: - raise ValueError( - f"Dataset {dataset.id} does not have a policy for confidential computing." - ) + # create dataset asset asset_path = generate_tmp_path() tar(asset_path, [dataset.data_path, dataset.labels_path]) @@ -46,16 +41,10 @@ def setup_dataset_for_cc(dataset: Dataset): def setup_model_for_cc(model: Model): + if not model.is_cc_configured(): + return cc_config = model.get_cc_config() cc_policy = model.get_cc_policy() - if not cc_config: - raise ValueError( - f"Model {model.id} does not have a configuration for confidential computing." - ) - if cc_policy is None: - raise ValueError( - f"Model {model.id} does not have a policy for confidential computing." - ) if model.type != "ASSET": raise ValueError( f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing." @@ -93,12 +82,12 @@ def __setup_asset_for_cc( def update_dataset_cc_policy(dataset: Dataset, permitted_workloads: list[CCWorkloadID]): - cc_config = dataset.get_cc_config() - if not cc_config: + if not dataset.is_cc_configured(): raise ValueError( f"Dataset {dataset.id} does not have a configuration for confidential computing." ) + cc_config = dataset.get_cc_config() encryption_key_folder = os.path.join( config.cc_artifacts_dir, "dataset" + str(dataset.id) ) @@ -109,11 +98,11 @@ def update_dataset_cc_policy(dataset: Dataset, permitted_workloads: list[CCWorkl def update_model_cc_policy(model: Model, permitted_workloads: list[CCWorkloadID]): - cc_config = model.get_cc_config() - if not cc_config: + if not model.is_cc_configured(): raise ValueError( f"Model {model.id} does not have a configuration for confidential computing." ) + cc_config = model.get_cc_config() if model.type != "ASSET": raise ValueError( f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing." 
@@ -129,12 +118,10 @@ def update_model_cc_policy(model: Model, permitted_workloads: list[CCWorkloadID] def setup_operator(user: User): - cc_config = user.get_cc_config() - if not cc_config: - raise ValueError( - "User does not have a configuration for confidential computing." - ) + if not user.is_cc_configured(): + return + cc_config = user.get_cc_config() operator_manager = OperatorManager(cc_config) operator_manager.setup() diff --git a/cli/medperf/commands/cc/cc.py b/cli/medperf/commands/cc/cc.py index 4e218f631..e6c4f437b 100644 --- a/cli/medperf/commands/cc/cc.py +++ b/cli/medperf/commands/cc/cc.py @@ -24,7 +24,7 @@ def configure_dataset_for_cc( ): """Configure dataset for confidential computing execution""" ui = config.ui - DatasetConfigureForCC.run(data_uid, cc_config_file, cc_policy_file) + DatasetConfigureForCC.run_from_files(data_uid, cc_config_file, cc_policy_file) ui.print("✅ Done!") @@ -41,7 +41,7 @@ def configure_model_for_cc( ): """Configure model for confidential computing execution""" ui = config.ui - ModelConfigureForCC.run(model_uid, cc_config_file, cc_policy_file) + ModelConfigureForCC.run_from_files(model_uid, cc_config_file, cc_policy_file) ui.print("✅ Done!") @@ -76,5 +76,5 @@ def setup_cc_operator( ): """Setup confidential computing operator""" ui = config.ui - SetupCCOperator.run(cc_config_file) + SetupCCOperator.run_from_files(cc_config_file) ui.print("✅ Done!") diff --git a/cli/medperf/commands/cc/dataset_configure_for_cc.py b/cli/medperf/commands/cc/dataset_configure_for_cc.py index e2262902f..4f7fd2cb7 100644 --- a/cli/medperf/commands/cc/dataset_configure_for_cc.py +++ b/cli/medperf/commands/cc/dataset_configure_for_cc.py @@ -6,19 +6,18 @@ class DatasetConfigureForCC: @classmethod - def run(cls, data_uid: int, cc_config_file: str, cc_policy_file: str): - dataset = Dataset.get(data_uid) + def run_from_files(cls, data_uid: int, cc_config_file: str, cc_policy_file: str): with open(cc_config_file) as f: cc_config = json.load(f) with 
open(cc_policy_file) as f: cc_policy = json.load(f) + cls.run(data_uid, cc_config, cc_policy) + + @classmethod + def run(cls, data_uid: int, cc_config: dict, cc_policy: dict): + dataset = Dataset.get(data_uid) dataset.set_cc_config(cc_config) dataset.set_cc_policy(cc_policy) - body = {"user_metadata": dataset.user_metadata} - config.comms.update_dataset(dataset.id, body) setup_dataset_for_cc(dataset) - - # mark as set - dataset.mark_cc_configured() body = {"user_metadata": dataset.user_metadata} config.comms.update_dataset(dataset.id, body) diff --git a/cli/medperf/commands/cc/model_configure_for_cc.py b/cli/medperf/commands/cc/model_configure_for_cc.py index 916c01cf8..1c9e6e0d5 100644 --- a/cli/medperf/commands/cc/model_configure_for_cc.py +++ b/cli/medperf/commands/cc/model_configure_for_cc.py @@ -6,18 +6,18 @@ class ModelConfigureForCC: @classmethod - def run(cls, model_uid: int, cc_config_file: str, cc_policy_file: str): - model = Model.get(model_uid) + def run_from_files(cls, model_uid: int, cc_config_file: str, cc_policy_file: str): with open(cc_config_file) as f: cc_config = json.load(f) with open(cc_policy_file) as f: cc_policy = json.load(f) + cls.run(model_uid, cc_config, cc_policy) + + @classmethod + def run(cls, model_uid: int, cc_config: dict, cc_policy: dict): + model = Model.get(model_uid) model.set_cc_config(cc_config) model.set_cc_policy(cc_policy) - body = {"user_metadata": model.user_metadata} - config.comms.update_model(model.id, body) setup_model_for_cc(model) - # mark as set - model.mark_cc_configured() body = {"user_metadata": model.user_metadata} config.comms.update_model(model.id, body) diff --git a/cli/medperf/commands/cc/setup_cc_operator.py b/cli/medperf/commands/cc/setup_cc_operator.py index 259d2360a..93cd7e3fd 100644 --- a/cli/medperf/commands/cc/setup_cc_operator.py +++ b/cli/medperf/commands/cc/setup_cc_operator.py @@ -6,17 +6,15 @@ class SetupCCOperator: @classmethod - def run(cls, cc_config_file: str): - user = 
get_medperf_user_object() + def run_from_files(cls, cc_config_file: str): with open(cc_config_file) as f: cc_config = json.load(f) + cls.run(cc_config) + @classmethod + def run(cls, cc_config: dict): + user = get_medperf_user_object() user.set_cc_config(cc_config) - body = {"metadata": user.metadata} - config.comms.update_user(user.id, body) setup_operator(user) - - # mark as set - user.mark_cc_configured() body = {"metadata": user.metadata} config.comms.update_user(user.id, body) diff --git a/cli/medperf/commands/execution/create.py b/cli/medperf/commands/execution/create.py index dbaf47b40..05e725056 100644 --- a/cli/medperf/commands/execution/create.py +++ b/cli/medperf/commands/execution/create.py @@ -237,7 +237,7 @@ def run_experiments(self) -> list[Execution]: "cached": False, "error": str(e), "partial": "N/A", - "confidential": model.is_cc_mode(), + "confidential": model.requires_cc(), } ) continue @@ -254,7 +254,7 @@ def run_experiments(self) -> list[Execution]: "cached": False, "error": "", "partial": execution_summary["partial"], - "confidential": model.is_cc_mode(), + "confidential": model.requires_cc(), } ) return [experiment["execution"] for experiment in self.experiments] diff --git a/cli/medperf/commands/execution/execution_flow.py b/cli/medperf/commands/execution/execution_flow.py index 35120a79f..867b746f3 100644 --- a/cli/medperf/commands/execution/execution_flow.py +++ b/cli/medperf/commands/execution/execution_flow.py @@ -26,7 +26,7 @@ def run( if ( model.type == ModelType.ASSET.value - and model.is_cc_mode() + and model.requires_cc() and not user_is_model_owner ): return ConfidentialExecution.run( diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 27a56287e..76bfa738d 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -8,7 +8,6 @@ from medperf.entities.schemas import DatasetSchema import medperf.config as config from medperf.account_management import get_medperf_user_data 
-from medperf.exceptions import MedperfException from medperf.entities.utils import handle_validation_error @@ -73,7 +72,7 @@ def local_id(self): def get_cc_config(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.user_metadata: @@ -82,24 +81,15 @@ def set_cc_config(self, cc_config: dict): def get_cc_policy(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("policy", None) + return cc_values.get("policy", {}) def set_cc_policy(self, cc_policy: dict): if "cc" not in self.user_metadata: self.user_metadata["cc"] = {} self.user_metadata["cc"]["policy"] = cc_policy - def mark_cc_configured(self): - if "cc" not in self.user_metadata: - raise MedperfException( - "Dataset does not have a cc configuration to be marked as configured" - ) - self.user_metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.user_metadata: - return False - return self.user_metadata["cc"].get("configured", False) + return self.get_cc_config() != {} def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) diff --git a/cli/medperf/entities/model.py b/cli/medperf/entities/model.py index d4ddcadbc..f92f0a4e5 100644 --- a/cli/medperf/entities/model.py +++ b/cli/medperf/entities/model.py @@ -76,12 +76,13 @@ def asset_obj(self): def is_encrypted(self) -> bool: return self.is_container() and self.container_obj.is_encrypted() - def is_cc_mode(self): - return "cc" in self.user_metadata + def requires_cc(self): + # for now, let's do this + return self.is_asset() and self.asset_obj.is_local() def get_cc_config(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.user_metadata: @@ -90,24 +91,15 @@ def 
set_cc_config(self, cc_config: dict): def get_cc_policy(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("policy", None) + return cc_values.get("policy", {}) def set_cc_policy(self, cc_policy: dict): if "cc" not in self.user_metadata: self.user_metadata["cc"] = {} self.user_metadata["cc"]["policy"] = cc_policy - def mark_cc_configured(self): - if "cc" not in self.user_metadata: - raise MedperfException( - "Model does not have a cc configuration to be marked as configured" - ) - self.user_metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.user_metadata: - return False - return self.user_metadata["cc"].get("configured", False) + return self.get_cc_config() != {} @staticmethod def remote_prefilter(filters: dict) -> callable: diff --git a/cli/medperf/entities/user.py b/cli/medperf/entities/user.py index 49ce1fd97..d7b836907 100644 --- a/cli/medperf/entities/user.py +++ b/cli/medperf/entities/user.py @@ -1,5 +1,4 @@ from medperf.entities.schemas import UserSchema -from medperf.exceptions import MedperfException from medperf.entities.utils import handle_validation_error @@ -31,21 +30,12 @@ def __setattr__(self, name, value): def get_cc_config(self): cc_values = self.metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.metadata: self.metadata["cc"] = {} self.metadata["cc"]["config"] = cc_config - def mark_cc_configured(self): - if "cc" not in self.metadata: - raise MedperfException( - "User does not have a cc configuration to be marked as configured" - ) - self.metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.metadata: - return False - return self.metadata["cc"].get("configured", False) + return self.get_cc_config() != {} From 374002b6918e06df9abfc5a5e34d466447f8394b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 21:05:43 +0100 Subject: [PATCH 05/72] 
validation for input cc config --- .../asset_management/asset_management.py | 21 ++++++++++++++++++- cli/medperf/asset_management/gcp_utils.py | 12 +++++------ .../commands/cc/dataset_configure_for_cc.py | 6 +++++- .../commands/cc/model_configure_for_cc.py | 6 +++++- cli/medperf/commands/cc/setup_cc_operator.py | 6 +++++- examples/cc/chestxray/dataset_cc_config.json | 2 -- examples/cc/chestxray/model_cc_config.json | 2 -- examples/cc/rano/dataset_cc_config.json | 2 -- examples/cc/rano/model_cc_config.json | 2 -- 9 files changed, 40 insertions(+), 19 deletions(-) diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py index 5d2e2e00b..2e2092af3 100644 --- a/cli/medperf/asset_management/asset_management.py +++ b/cli/medperf/asset_management/asset_management.py @@ -1,4 +1,8 @@ -from medperf.asset_management.gcp_utils import CCWorkloadID +from medperf.asset_management.gcp_utils import ( + CCWorkloadID, + GCPAssetConfig, + GCPOperatorConfig, +) from medperf.entities.dataset import Dataset from medperf.entities.model import Model from medperf.entities.user import User @@ -19,6 +23,21 @@ def generate_encryption_key(encryption_key_file: str): f.write(secrets.token_bytes(32)) +def validate_cc_config(cc_config: dict, asset_name_prefix: str): + if cc_config == {}: + return + + cc_config["encrypted_asset_bucket_file"] = asset_name_prefix + ".enc" + cc_config["encrypted_key_bucket_file"] = asset_name_prefix + "_key.enc" + GCPAssetConfig(**cc_config) + + +def validate_cc_operator_config(cc_config: dict): + if cc_config == {}: + return + GCPOperatorConfig(**cc_config) + + def setup_dataset_for_cc(dataset: Dataset): if not dataset.is_cc_configured(): return diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index 37f80c4d5..88df0a803 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -4,7 +4,6 @@ from typing import Union 
from medperf.exceptions import ExecutionError from medperf.utils import run_command -from dataclasses import dataclass from google.cloud import kms from google.iam.v1 import policy_pb2 from google.cloud import storage @@ -13,12 +12,13 @@ import time from colorama import Fore, Style import medperf.config as medperf_config +from pydantic import BaseModel GCP_EXEC = "gcloud" -@dataclass -class CCWorkloadID: +# TODO: validation of inputs +class CCWorkloadID(BaseModel): data_hash: str model_hash: str script_hash: str @@ -67,8 +67,7 @@ def results_encryption_key_path(self): return f"{self.human_readable_id}/encryption_key" -@dataclass -class GCPOperatorConfig: +class GCPOperatorConfig(BaseModel): project_id: str service_account_name: str account: str @@ -88,8 +87,7 @@ def service_account_email(self): return f"{self.service_account_name}@{self.project_id}.iam.gserviceaccount.com" -@dataclass -class GCPAssetConfig: +class GCPAssetConfig(BaseModel): project_id: str project_number: str account: str diff --git a/cli/medperf/commands/cc/dataset_configure_for_cc.py b/cli/medperf/commands/cc/dataset_configure_for_cc.py index 4f7fd2cb7..bea51f72d 100644 --- a/cli/medperf/commands/cc/dataset_configure_for_cc.py +++ b/cli/medperf/commands/cc/dataset_configure_for_cc.py @@ -1,5 +1,8 @@ from medperf.entities.dataset import Dataset -from medperf.asset_management.asset_management import setup_dataset_for_cc +from medperf.asset_management.asset_management import ( + setup_dataset_for_cc, + validate_cc_config, +) import json from medperf import config @@ -15,6 +18,7 @@ def run_from_files(cls, data_uid: int, cc_config_file: str, cc_policy_file: str) @classmethod def run(cls, data_uid: int, cc_config: dict, cc_policy: dict): + validate_cc_config(cc_config, "dataset" + str(data_uid)) dataset = Dataset.get(data_uid) dataset.set_cc_config(cc_config) dataset.set_cc_policy(cc_policy) diff --git a/cli/medperf/commands/cc/model_configure_for_cc.py 
b/cli/medperf/commands/cc/model_configure_for_cc.py index 1c9e6e0d5..9f853ab8c 100644 --- a/cli/medperf/commands/cc/model_configure_for_cc.py +++ b/cli/medperf/commands/cc/model_configure_for_cc.py @@ -1,5 +1,8 @@ from medperf.entities.model import Model -from medperf.asset_management.asset_management import setup_model_for_cc +from medperf.asset_management.asset_management import ( + setup_model_for_cc, + validate_cc_config, +) import json from medperf import config @@ -15,6 +18,7 @@ def run_from_files(cls, model_uid: int, cc_config_file: str, cc_policy_file: str @classmethod def run(cls, model_uid: int, cc_config: dict, cc_policy: dict): + validate_cc_config(cc_config, "model" + str(model_uid)) model = Model.get(model_uid) model.set_cc_config(cc_config) model.set_cc_policy(cc_policy) diff --git a/cli/medperf/commands/cc/setup_cc_operator.py b/cli/medperf/commands/cc/setup_cc_operator.py index 93cd7e3fd..cf7e901ea 100644 --- a/cli/medperf/commands/cc/setup_cc_operator.py +++ b/cli/medperf/commands/cc/setup_cc_operator.py @@ -1,5 +1,8 @@ import json -from medperf.asset_management.asset_management import setup_operator +from medperf.asset_management.asset_management import ( + setup_operator, + validate_cc_operator_config, +) from medperf.account_management import get_medperf_user_object from medperf import config @@ -13,6 +16,7 @@ def run_from_files(cls, cc_config_file: str): @classmethod def run(cls, cc_config: dict): + validate_cc_operator_config(cc_config) user = get_medperf_user_object() user.set_cc_config(cc_config) setup_operator(user) diff --git a/examples/cc/chestxray/dataset_cc_config.json b/examples/cc/chestxray/dataset_cc_config.json index 3b7851399..5a81e394b 100644 --- a/examples/cc/chestxray/dataset_cc_config.json +++ b/examples/cc/chestxray/dataset_cc_config.json @@ -3,8 +3,6 @@ "project_number": "819939352708", "account": "hasan.kassem@mlcommons.org", "bucket": "data-owner-bucket-medperf-cc", - "encrypted_asset_bucket_file": "data.enc", - 
"encrypted_key_bucket_file": "key.enc", "keyring_name": "data-owner-keyring", "key_name": "data-owner-key2", "wip": "data-owner-wip" diff --git a/examples/cc/chestxray/model_cc_config.json b/examples/cc/chestxray/model_cc_config.json index 0c7b5b530..1ba002a41 100644 --- a/examples/cc/chestxray/model_cc_config.json +++ b/examples/cc/chestxray/model_cc_config.json @@ -3,8 +3,6 @@ "project_number": "819939352708", "account": "hasan.kassem@mlcommons.org", "bucket": "model-owner-bucket-medperf", - "encrypted_asset_bucket_file": "model.enc", - "encrypted_key_bucket_file": "key.enc", "keyring_name": "model-owner-keyring", "key_name": "model-owner-key", "wip": "model-owner-wip" diff --git a/examples/cc/rano/dataset_cc_config.json b/examples/cc/rano/dataset_cc_config.json index 3b7851399..5a81e394b 100644 --- a/examples/cc/rano/dataset_cc_config.json +++ b/examples/cc/rano/dataset_cc_config.json @@ -3,8 +3,6 @@ "project_number": "819939352708", "account": "hasan.kassem@mlcommons.org", "bucket": "data-owner-bucket-medperf-cc", - "encrypted_asset_bucket_file": "data.enc", - "encrypted_key_bucket_file": "key.enc", "keyring_name": "data-owner-keyring", "key_name": "data-owner-key2", "wip": "data-owner-wip" diff --git a/examples/cc/rano/model_cc_config.json b/examples/cc/rano/model_cc_config.json index 0c7b5b530..1ba002a41 100644 --- a/examples/cc/rano/model_cc_config.json +++ b/examples/cc/rano/model_cc_config.json @@ -3,8 +3,6 @@ "project_number": "819939352708", "account": "hasan.kassem@mlcommons.org", "bucket": "model-owner-bucket-medperf", - "encrypted_asset_bucket_file": "model.enc", - "encrypted_key_bucket_file": "key.enc", "keyring_name": "model-owner-keyring", "key_name": "model-owner-key", "wip": "model-owner-wip" From d2a949a99344beba1fae2f10e512d0f5b0b86e5b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 9 Mar 2026 23:13:44 +0100 Subject: [PATCH 06/72] change required fields --- cli/medperf/asset_management/gcp_utils.py | 22 +++++++++++++------ 
examples/cc/chestxray/operator_cc_config.json | 8 ++----- examples/cc/rano/operator_cc_config.json | 8 ++----- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index 88df0a803..1d641187d 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -73,14 +73,22 @@ class GCPOperatorConfig(BaseModel): account: str bucket: str machine_type: str - boot_disk_size: str - cc_type: str - min_cpu_platform: str + boot_disk_size: int # GB vm_zone: str vm_network: str - logs_poll_frequency: int + logs_poll_frequency: int = 30 # seconds gpu: bool - run_duration: str + run_duration: int = 24 # hours, only applicable for GPU workloads + + @property + def min_cpu_platform(self): + # TODO: check + return "AMD Milan" + + @property + def cc_type(self): + # TODO: check + return "SEV" @property def service_account_email(self): @@ -434,9 +442,9 @@ def run_gpu_workload( "--image-family=confidential-space-debug-preview-cgpu", f"--service-account={config.service_account_email}", "--scopes=cloud-platform", - f"--boot-disk-size={config.boot_disk_size}", + f"--boot-disk-size={config.boot_disk_size}G", "--reservation-affinity=none", - f"--max-run-duration={config.run_duration}", + f"--max-run-duration={config.run_duration}h", "--instance-termination-action=DELETE", f"--metadata={metadata}", ] diff --git a/examples/cc/chestxray/operator_cc_config.json b/examples/cc/chestxray/operator_cc_config.json index f9ca381e7..3d503e9d3 100644 --- a/examples/cc/chestxray/operator_cc_config.json +++ b/examples/cc/chestxray/operator_cc_config.json @@ -4,12 +4,8 @@ "account": "hasan.kassem@mlcommons.org", "bucket": "data-owner-bucket-medperf-cc", "machine_type": "n2d-standard-8", - "cc_type": "SEV", - "min_cpu_platform": "AMD Milan", "vm_zone": "us-west1-b", "vm_network": "medperf-brats-network", - "boot_disk_size": "100GB", - "logs_poll_frequency": 30, - "gpu": 
false, - "run_duration": "" + "boot_disk_size": 100, + "gpu": false } \ No newline at end of file diff --git a/examples/cc/rano/operator_cc_config.json b/examples/cc/rano/operator_cc_config.json index f9ca381e7..3d503e9d3 100644 --- a/examples/cc/rano/operator_cc_config.json +++ b/examples/cc/rano/operator_cc_config.json @@ -4,12 +4,8 @@ "account": "hasan.kassem@mlcommons.org", "bucket": "data-owner-bucket-medperf-cc", "machine_type": "n2d-standard-8", - "cc_type": "SEV", - "min_cpu_platform": "AMD Milan", "vm_zone": "us-west1-b", "vm_network": "medperf-brats-network", - "boot_disk_size": "100GB", - "logs_poll_frequency": 30, - "gpu": false, - "run_duration": "" + "boot_disk_size": 100, + "gpu": false } \ No newline at end of file From 322409e711dc39a39c3133371b4d333456d286ce Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 00:19:39 +0100 Subject: [PATCH 07/72] update cc setup to checks only --- .../asset_management/__gcp_util_archive.py | 186 ++++++++++++++ cli/medperf/asset_management/asset_check.py | 56 +++++ .../asset_management/asset_management.py | 7 +- .../asset_management/asset_policy_manager.py | 50 ++-- .../asset_management/asset_storage_manager.py | 23 +- cli/medperf/asset_management/cc_operator.py | 127 ++++------ cli/medperf/asset_management/checks_utils.py | 233 ++++++++++++++++++ cli/medperf/asset_management/gcp_utils.py | 200 +-------------- .../asset_management/operator_check.py | 59 +++++ cli/requirements.txt | 7 +- 10 files changed, 623 insertions(+), 325 deletions(-) create mode 100644 cli/medperf/asset_management/__gcp_util_archive.py create mode 100644 cli/medperf/asset_management/asset_check.py create mode 100644 cli/medperf/asset_management/checks_utils.py create mode 100644 cli/medperf/asset_management/operator_check.py diff --git a/cli/medperf/asset_management/__gcp_util_archive.py b/cli/medperf/asset_management/__gcp_util_archive.py new file mode 100644 index 000000000..f237d55cd --- /dev/null +++ 
b/cli/medperf/asset_management/__gcp_util_archive.py @@ -0,0 +1,186 @@ +import logging +from typing import Union +from medperf.exceptions import ExecutionError +from medperf.utils import run_command + +from google.cloud import storage +from google.cloud.exceptions import Conflict + +from .gcp_utils import GCPOperatorConfig, GCPAssetConfig + +GCP_EXEC = "gcloud" + + +# IAM Service Account operations +def create_service_account(config: GCPOperatorConfig): + """Create service account for workload.""" + cmd = [ + GCP_EXEC, + "iam", + "service-accounts", + "create", + config.service_account_name, + ] + try: + run_command(cmd) + except ExecutionError as e: + logging.debug( + f"Service account {config.service_account_name} may already exist. Error: {e}" + ) + + +def add_service_account_iam_policy_binding( + config: GCPOperatorConfig, member: str, role: str +): + """Add IAM policy binding to service account.""" + cmd = [ + GCP_EXEC, + "iam", + "service-accounts", + "add-iam-policy-binding", + config.service_account_email, + f"--member={member}", + f"--role={role}", + ] + run_command(cmd) + + +def add_project_iam_policy_binding(config: GCPOperatorConfig, member: str, role: str): + """Add IAM policy binding to project.""" + cmd = [ + GCP_EXEC, + "projects", + "add-iam-policy-binding", + config.project_id, + f"--member={member}", + f"--role={role}", + ] + run_command(cmd) + + +# KMS operations +def create_keyring(config: GCPAssetConfig): + """Create KMS keyring.""" + cmd = [ + GCP_EXEC, + "kms", + "keyrings", + "create", + config.keyring_name, + "--location=global", + ] + try: + run_command(cmd) + except ExecutionError as e: + logging.debug(f"Keyring {config.keyring_name} may already exist. 
Error: {e}") + + +def create_kms_key(config: GCPAssetConfig): + """Create KMS key.""" + cmd = [ + GCP_EXEC, + "kms", + "keys", + "create", + config.key_name, + "--location=global", + f"--keyring={config.keyring_name}", + "--purpose=encryption", + ] + try: + run_command(cmd) + except ExecutionError as e: + logging.debug( + f"Key {config.key_name} may already exist in keyring {config.keyring_name}. Error: {e}" + ) + + +def add_kms_key_iam_policy_binding(config: GCPAssetConfig, member: str, role: str): + cmd = [ + GCP_EXEC, + "kms", + "keys", + "add-iam-policy-binding", + config.full_key_name, + f"--member={member}", + f"--role={role}", + ] + run_command(cmd) + + +# Workload Identity Pool operations +def create_workload_identity_pool(config: GCPAssetConfig): + """Create workload identity pool.""" + cmd = [ + GCP_EXEC, + "iam", + "workload-identity-pools", + "create", + config.wip, + "--location=global", + ] + try: + run_command(cmd) + except ExecutionError as e: + logging.debug( + f"Workload identity pool {config.wip} may already exist. Error: {e}" + ) + + +def create_workload_identity_pool_oidc_provider( + config: GCPAssetConfig, attribute_mapping: str, attribute_condition: str +): + """Create OIDC provider for workload identity pool.""" + cmd = [ + GCP_EXEC, + "iam", + "workload-identity-pools", + "providers", + "create-oidc", + "attestation-verifier", + "--location=global", + f"--workload-identity-pool={config.wip}", + "--issuer-uri=https://confidentialcomputing.googleapis.com/", + "--allowed-audiences=https://sts.googleapis.com", + f"--attribute-mapping={attribute_mapping}", + f"--attribute-condition={attribute_condition}", + ] + try: + run_command(cmd) + return + except ExecutionError as e: + logging.debug( + f"OIDC provider for workload identity pool {config.wip} may already exist. 
Error: {e}" + ) + + +# Storage operations +def create_storage_bucket(config: Union[GCPAssetConfig, GCPOperatorConfig]): + """Create GCS bucket.""" + client = storage.Client(project=config.project_id) + for bucket in client.list_buckets(project=config.project_id): + if bucket.name == config.bucket: + logging.debug(f"Bucket {config.bucket} already exists.") + return + + # try creating the bucket + try: + client.create_bucket(config.bucket, project=config.project_id) + except Conflict as e: + logging.debug(f"Bucket {config.bucket} already exists. Conflict: {e}") + + +def add_bucket_iam_policy_binding( + config: Union[GCPAssetConfig, GCPOperatorConfig], member: str, role: str +): + """Add IAM policy binding to GCS bucket.""" + cmd = [ + GCP_EXEC, + "storage", + "buckets", + "add-iam-policy-binding", + f"gs://{config.bucket}", + f"--member={member}", + f"--role={role}", + ] + run_command(cmd) diff --git a/cli/medperf/asset_management/asset_check.py b/cli/medperf/asset_management/asset_check.py new file mode 100644 index 000000000..5b6f83a6e --- /dev/null +++ b/cli/medperf/asset_management/asset_check.py @@ -0,0 +1,56 @@ +import google.auth +from medperf.asset_management.checks_utils import ( + check_user_role_on_bucket, + check_user_role_on_kms_key, + check_user_role_on_wip, +) +from google.auth.credentials import AnonymousCredentials + + +def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): + base_creds, _ = google.auth.default() + result = check_user_role_on_bucket( + "user", + base_creds, + bucket_name, + "roles/storage.objectAdmin", + ) + if result: + return False, result + + result = check_user_role_on_kms_key( + base_creds, + kms_key_resource, + "roles/cloudkms.cryptoKeyEncrypter", + ) + + if result: + return False, result + + result = check_user_role_on_kms_key( + base_creds, + kms_key_resource, + "roles/cloudkms.admin", + ) + + if result: + return False, result + + anon_creds = AnonymousCredentials() + + result = check_user_role_on_bucket( 
+ "anonymous user", + anon_creds, + bucket_name, + "roles/storage.objectViewer", + ) + if result: + return False, result + + result = check_user_role_on_wip( + base_creds, + wip_resource, + "roles/iam.workloadIdentityPoolAdmin", + ) + + return True, "" diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py index 2e2092af3..b6a7fe09e 100644 --- a/cli/medperf/asset_management/asset_management.py +++ b/cli/medperf/asset_management/asset_management.py @@ -87,16 +87,17 @@ def setup_model_for_cc(model: Model): def __setup_asset_for_cc( cc_config: dict, cc_policy: dict, asset_path: str, encryption_key_file: str ): - # asset storage setup asset_storage_manager = AssetStorageManager( cc_config, asset_path, encryption_key_file ) + asset_policy_manager = AssetPolicyManager(cc_config, encryption_key_file) asset_storage_manager.setup() + asset_policy_manager.setup() + + # storage asset_storage_manager.store_asset() # policy setup - asset_policy_manager = AssetPolicyManager(cc_config, encryption_key_file) - asset_policy_manager.setup() asset_policy_manager.setup_policy(cc_policy) diff --git a/cli/medperf/asset_management/asset_policy_manager.py b/cli/medperf/asset_management/asset_policy_manager.py index 0d835580b..16f9e8e6c 100644 --- a/cli/medperf/asset_management/asset_policy_manager.py +++ b/cli/medperf/asset_management/asset_policy_manager.py @@ -1,44 +1,35 @@ from medperf.utils import generate_tmp_path -from medperf.asset_management import gcp_utils +from medperf.asset_management.gcp_utils import ( + GCPAssetConfig, + CCWorkloadID, + upload_file_to_gcs, + encrypt_with_kms_key, + set_kms_iam_policy, + update_workload_identity_pool_oidc_provider, +) class AssetPolicyManager: def __init__(self, config: dict, encryption_key_file: str): - self.config = gcp_utils.GCPAssetConfig(**config) + self.config = GCPAssetConfig(**config) self.encryption_key_file = encryption_key_file - def __create_keyring(self): - 
gcp_utils.create_keyring(self.config) - - def __create_key(self): - gcp_utils.create_kms_key(self.config) - - def __add_key_iam_binding(self): - gcp_utils.add_kms_key_iam_policy_binding( - self.config, - f"user:{self.config.account}", - "roles/cloudkms.cryptoKeyEncrypter", - ) - - def __create_workload_identity_pool(self): - gcp_utils.create_workload_identity_pool(self.config) - def __encrypt_key(self): tmp_encrypted_key_path = generate_tmp_path() - gcp_utils.encrypt_with_kms_key( + encrypt_with_kms_key( self.config, self.encryption_key_file, tmp_encrypted_key_path ) return tmp_encrypted_key_path def __upload_encrypted_key(self, tmp_encrypted_key_path): - gcp_utils.upload_file_to_gcs( + upload_file_to_gcs( self.config, tmp_encrypted_key_path, f"gs://{self.config.bucket}/{self.config.encrypted_key_bucket_file}", ) - def __create_wip_oidc_provider(self, policy: dict[str, str]): + def __update_wip_oidc_provider(self, policy: dict[str, str]): # IMPORTANT: https://docs.cloud.google.com/confidential-computing/ # confidential-space/docs/create-grant-access-confidential-resources#attestation-assertions google_subject_attr = ( @@ -74,13 +65,11 @@ def __create_wip_oidc_provider(self, policy: dict[str, str]): ) attribute_condition += f" && {gpu_cc_mode_condition}" - gcp_utils.create_workload_identity_pool_oidc_provider( + update_workload_identity_pool_oidc_provider( self.config, attribute_mapping, attribute_condition ) - def __bind_kms_decrypter_role( - self, permitted_workloads: list[gcp_utils.CCWorkloadID] - ): + def __bind_kms_decrypter_role(self, permitted_workloads: list[CCWorkloadID]): principal_set = ( f"principalSet://iam.googleapis.com/projects/{self.config.project_number}/" f"locations/global/workloadIdentityPools/{self.config.wip}/attribute.workload_uid/" @@ -90,22 +79,19 @@ def __bind_kms_decrypter_role( for workload in permitted_workloads: principal_set_list.append(principal_set + workload.id) - gcp_utils.set_kms_iam_policy( + set_kms_iam_policy( self.config, 
principal_set_list, "roles/cloudkms.cryptoKeyDecrypter", ) def setup(self): - self.__create_keyring() - self.__create_key() - self.__add_key_iam_binding() - self.__create_workload_identity_pool() + pass def setup_policy(self, policy: dict[str, str]): tmp_encrypted_key_path = self.__encrypt_key() self.__upload_encrypted_key(tmp_encrypted_key_path) - self.__create_wip_oidc_provider(policy) + self.__update_wip_oidc_provider(policy) - def configure_policy(self, permitted_workloads: list[gcp_utils.CCWorkloadID]): + def configure_policy(self, permitted_workloads: list[CCWorkloadID]): self.__bind_kms_decrypter_role(permitted_workloads) diff --git a/cli/medperf/asset_management/asset_storage_manager.py b/cli/medperf/asset_management/asset_storage_manager.py index bc8ff1d9e..0c2aadc96 100644 --- a/cli/medperf/asset_management/asset_storage_manager.py +++ b/cli/medperf/asset_management/asset_storage_manager.py @@ -1,18 +1,17 @@ from medperf.utils import generate_tmp_path, get_file_hash from medperf.encryption import SymmetricEncryption -from medperf.asset_management import gcp_utils +from medperf.asset_management.gcp_utils import GCPAssetConfig, upload_file_to_gcs +from medperf.asset_management.asset_check import verify_asset_owner_setup +from medperf.exceptions import MedperfException class AssetStorageManager: def __init__(self, config: dict, asset_path: str, encryption_key_file: str): - self.config = gcp_utils.GCPAssetConfig(**config) + self.config = GCPAssetConfig(**config) self.asset_path = asset_path self.encryption_key_file = encryption_key_file - def __create_bucket(self): - gcp_utils.create_storage_bucket(self.config) - def __encrypt_asset(self): tmp_encrypted_asset_path = generate_tmp_path() SymmetricEncryption().encrypt_file( @@ -22,20 +21,18 @@ def __encrypt_asset(self): return tmp_encrypted_asset_path, asset_hash def __upload_encrypted_asset(self, tmp_encrypted_asset_path): - gcp_utils.upload_file_to_gcs( + upload_file_to_gcs( self.config, 
tmp_encrypted_asset_path, f"gs://{self.config.bucket}/{self.config.encrypted_asset_bucket_file}", ) - def __grant_bucket_public_read_access(self): - gcp_utils.add_bucket_iam_policy_binding( - self.config, "allUsers", "roles/storage.objectViewer" - ) - def setup(self): - self.__create_bucket() - self.__grant_bucket_public_read_access() + success, message = verify_asset_owner_setup( + self.config.bucket, self.config.full_key_name, self.config.full_wip_name + ) + if not success: + raise MedperfException(f"Asset owner setup verification failed: {message}") def store_asset(self): tmp_encrypted_asset_path, asset_hash = self.__encrypt_asset() diff --git a/cli/medperf/asset_management/cc_operator.py b/cli/medperf/asset_management/cc_operator.py index 6fc2cb793..4a740353a 100644 --- a/cli/medperf/asset_management/cc_operator.py +++ b/cli/medperf/asset_management/cc_operator.py @@ -1,95 +1,37 @@ import json -from medperf.asset_management import gcp_utils +from medperf.asset_management.gcp_utils import ( + GCPOperatorConfig, + CCWorkloadID, + download_file_from_gcs, + run_workload, + run_gpu_workload, + wait_for_workload_completion, +) +from medperf.asset_management.operator_check import verify_operator_setup +from medperf.exceptions import MedperfException from medperf.utils import generate_tmp_path, untar from medperf.encryption import SymmetricEncryption, AsymmetricEncryption -def create_service_account(config: gcp_utils.GCPOperatorConfig): - gcp_utils.create_service_account(config) - - -def allow_operator_to_use_service_account(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_service_account_iam_policy_binding( - config, - f"user:{config.account}", - "roles/iam.serviceAccountUser", - ) - - -def grant_confidential_computing_workload_user(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_project_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/confidentialcomputing.workloadUser", - ) - - -def 
grant_logging_log_writer(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_project_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/logging.logWriter", - ) - - -def run_workload( - config: gcp_utils.GCPOperatorConfig, - docker_image: str, - env_vars: dict, - workload: gcp_utils.CCWorkloadID, -): - # Build metadata string - metadata_parts = [ - f"tee-image-reference={docker_image}", - "tee-container-log-redirect=true", - ] - - # Add environment variables - for key, value in env_vars.items(): - metadata_parts.append(f"tee-env-{key}={value}") - - if config.gpu: - metadata_parts.append("tee-install-gpu-driver=true") - - metadata = "^~^" + "~".join(metadata_parts) - - if config.gpu: - gcp_utils.run_gpu_workload(config, workload, metadata) - else: - gcp_utils.run_workload(config, workload, metadata) - - -def grant_bucket_read_access(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_bucket_iam_policy_binding( - config, f"user:{config.account}", "roles/storage.objectViewer" - ) - - -def grant_bucket_write_access(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_bucket_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/storage.objectAdmin", - ) - - class OperatorManager: def __init__(self, config: dict): - self.config = gcp_utils.GCPOperatorConfig(**config) + self.config = GCPOperatorConfig(**config) def setup(self): """Set up complete operator infrastructure""" - create_service_account(self.config) - allow_operator_to_use_service_account(self.config) - grant_confidential_computing_workload_user(self.config) - grant_logging_log_writer(self.config) - grant_bucket_read_access(self.config) - grant_bucket_write_access(self.config) + success, message = verify_operator_setup( + self.config.service_account_email, + self.config.project_id, + self.config.bucket, + ) + + if not success: + raise MedperfException(f"Operator setup verification failed: {message}") def run_workload( self, 
docker_image: str, - workload: gcp_utils.CCWorkloadID, + workload: CCWorkloadID, dataset_cc_config: dict, model_cc_config: dict, result_collector_public_key: str, @@ -115,14 +57,31 @@ def run_workload( "RESULT_COLLECTOR": result_collector_public_key, "EXPECTED_RESULT_COLLECTOR_HASH": workload.result_collector_hash, } - run_workload(self.config, docker_image, env_vars, workload) + metadata_parts = [ + f"tee-image-reference={docker_image}", + "tee-container-log-redirect=true", + ] + + # Add environment variables + for key, value in env_vars.items(): + metadata_parts.append(f"tee-env-{key}={value}") + + if self.config.gpu: + metadata_parts.append("tee-install-gpu-driver=true") + + metadata = "^~^" + "~".join(metadata_parts) + + if self.config.gpu: + run_gpu_workload(self.config, workload, metadata) + else: + run_workload(self.config, workload, metadata) - def wait_for_workload_completion(self, workload: gcp_utils.CCWorkloadID): - gcp_utils.wait_for_workload_completion(self.config, workload) + def wait_for_workload_completion(self, workload: CCWorkloadID): + wait_for_workload_completion(self.config, workload) def download_results( self, - workload: gcp_utils.CCWorkloadID, + workload: CCWorkloadID, private_key_bytes: bytes, results_path: str, ): @@ -130,12 +89,12 @@ def download_results( encrypted_results_path = generate_tmp_path() key_path = generate_tmp_path() - gcp_utils.download_file_from_gcs( + download_file_from_gcs( self.config, f"gs://{self.config.bucket}/{workload.results_path}", encrypted_results_path, ) - gcp_utils.download_file_from_gcs( + download_file_from_gcs( self.config, f"gs://{self.config.bucket}/{workload.results_encryption_key_path}", key_path, diff --git a/cli/medperf/asset_management/checks_utils.py b/cli/medperf/asset_management/checks_utils.py new file mode 100644 index 000000000..7ef5ae8ad --- /dev/null +++ b/cli/medperf/asset_management/checks_utils.py @@ -0,0 +1,233 @@ +import logging +from google.auth import impersonated_credentials +import 
googleapiclient.discovery + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +def get_testable_permissions(resource): + if resource.startswith("//cloudresourcemanager"): + # don't find testable permissions for cloudresourcemanager + # Since they are a lot. + return { + "confidentialcomputing.challenges.create", + "confidentialcomputing.challenges.verify", + "confidentialcomputing.locations.get", + "confidentialcomputing.locations.list", + "logging.logEntries.create", + "logging.logEntries.route", + } + if "workloadIdentityPools" in resource: + # doesn't seem to have an api for this + return { + "iam.googleapis.com/workloadIdentityPoolProviders.update", + "iam.googleapis.com/workloadIdentityPoolProviders.get", + "iam.googleapis.com/workloadIdentityPools.get", + } + + iam = googleapiclient.discovery.build("iam", "v1") + resp = ( + iam.permissions() + .queryTestablePermissions(body={"fullResourceName": resource, "pageSize": 1000}) + .execute() + ) + return set(p["name"] for p in resp.get("permissions", [])) + + +def get_role_permissions(role_name: str, resource: str): + service = googleapiclient.discovery.build("iam", "v1") + role = service.roles().get(name=role_name).execute() + permissions = role.get("includedPermissions", []) + testable = get_testable_permissions(resource) + return list(testable.intersection(permissions)) + + +def impersonate_service_account(base_creds, sa_email): + logging.debug(f"Impersonating service account: {sa_email}") + try: + return impersonated_credentials.Credentials( + source_credentials=base_creds, + target_principal=sa_email, + target_scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + except Exception as e: + logging.debug(f"Failed to impersonate service account {sa_email}: {e}") + return None + + +# --------------------------------------------------------------------------- +# User roles 
+# --------------------------------------------------------------------------- +def check_user_role_on_service_account(base_creds, sa_email, role): + permissions = get_role_permissions( + role, "//iam.googleapis.com/projects/_/serviceAccounts/_" + ) + logging.debug(f"Checking if user has {role} role on {sa_email}") + try: + iam = googleapiclient.discovery.build( + "iam", "v1", credentials=base_creds, cache_discovery=False + ) + granted = ( + iam.projects() + .serviceAccounts() + .testIamPermissions( + resource=f"projects/-/serviceAccounts/{sa_email}", + body={"permissions": permissions}, + ) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return f"(Role {role}) User missing permissions: {missing} on service account: {sa_email}" + return None + except Exception as e: + logging.debug(f"check_user_role_on_service_account exception: {e}") + return f"Failed to verify user role on service account: {sa_email}" + + +def check_user_role_on_bucket(user_str, creds, bucket_name, role): + permissions = get_role_permissions( + role, "//storage.googleapis.com/projects/_/buckets/_" + ) + logging.debug(f"Checking if {user_str} has {role} role on bucket: {bucket_name}") + try: + iam = googleapiclient.discovery.build( + "storage", "v1", credentials=creds, cache_discovery=False + ) + granted = ( + iam.buckets() + .testIamPermissions(bucket=bucket_name, permissions=permissions) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return f"(Role {role}) {user_str} missing permissions: {missing} on bucket: {bucket_name}" + return None + except Exception as e: + logging.debug(f"check_user_role_on_bucket exception: {e}") + return f"Failed to verify {user_str} role on bucket: {bucket_name}" + + +# 
--------------------------------------------------------------------------- +# Service Account roles +# --------------------------------------------------------------------------- + + +def check_sa_roles_for_project(sa_creds, project_id, role): + permissions = get_role_permissions( + role, "//cloudresourcemanager.googleapis.com/projects/_" + ) + logging.debug(f"Checking service account project permissions: {project_id}") + try: + crm = googleapiclient.discovery.build( + "cloudresourcemanager", "v1", credentials=sa_creds, cache_discovery=False + ) + granted = ( + crm.projects() + .testIamPermissions(resource=project_id, body={"permissions": permissions}) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return ( + f"(Role {role}) Service account missing permissions: " + f"{missing} on project: {project_id}" + ) + return None + except Exception as e: + logging.debug(f"check_sa_project_permissions exception: {e}") + return f"Failed to verify service account role on project: {project_id}" + + +# --------------------------------------------------------------------------- +# KMS roles +# --------------------------------------------------------------------------- + + +def check_user_role_on_kms_key(base_creds, kms_key_resource, role): + logging.debug(f"Checking user role {role} on KMS key {kms_key_resource}") + + try: + kms = googleapiclient.discovery.build( + "cloudkms", "v1", credentials=base_creds, cache_discovery=False + ) + + permissions = get_role_permissions( + role, + "//cloudkms.googleapis.com/projects/_/locations/_/keyRings/_/cryptoKeys/_", + ) + + granted = ( + kms.projects() + .locations() + .keyRings() + .cryptoKeys() + .testIamPermissions( + resource=kms_key_resource, + body={"permissions": permissions}, + ) + .execute() + ) + + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - 
set(granted_permissions) + + if missing: + return f"(Role {role}) User missing permissions: {missing} on KMS key {kms_key_resource}" + + return None + + except Exception as e: + logging.debug(f"KMS permission check failed: {e}") + return f"Failed verifying user roles on KMS key {kms_key_resource}" + + +# --------------------------------------------------------------------------- +# WIP roles +# --------------------------------------------------------------------------- + + +def check_user_role_on_wip(creds, wip, role): + + logging.debug(f"Checking user role {role} on WIP {wip}") + + try: + iam = googleapiclient.discovery.build( + "iam", "v1", credentials=creds, cache_discovery=False + ) + + permissions = get_role_permissions( + role, "//iam.googleapis.com/projects/_/locations/_/workloadIdentityPools/_" + ) + + granted = ( + iam.projects() + .locations() + .workloadIdentityPools() + .testIamPermissions( + resource=wip, + body={"permissions": permissions}, + ) + .execute() + ) + + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + + if missing: + return f"(Role {role}) User missing permissions: {missing} on WIP {wip}" + + return None + + except Exception as e: + logging.debug(f"WIP permission check failed: {e}") + return f"Failed verifying user roles on WIP {wip}" diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index 1d641187d..844b9abac 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -2,12 +2,9 @@ import logging from typing import Union -from medperf.exceptions import ExecutionError from medperf.utils import run_command from google.cloud import kms from google.iam.v1 import policy_pb2 -from google.cloud import storage -from google.cloud.exceptions import Conflict from google.cloud import compute_v1 import time from colorama import Fore, Style @@ -114,116 +111,20 @@ def full_key_name(self) -> str: ) 
@property - def full_wip_name(self) -> str: + def full_wip_provider_name(self) -> str: return ( f"projects/{self.project_number}/locations/global/" f"workloadIdentityPools/{self.wip}/providers/attestation-verifier" ) - -# IAM Service Account operations -def create_service_account(config: GCPOperatorConfig): - """Create service account for workload.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "iam", - "service-accounts", - "create", - config.service_account_name, - ] - try: - run_command(cmd) - except ExecutionError as e: - logging.debug( - f"Service account {config.service_account_name} may already exist. Error: {e}" - ) - - -def add_service_account_iam_policy_binding( - config: GCPOperatorConfig, member: str, role: str -): - """Add IAM policy binding to service account.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "iam", - "service-accounts", - "add-iam-policy-binding", - config.service_account_email, - f"--member={member}", - f"--role={role}", - ] - run_command(cmd) - - -def add_project_iam_policy_binding(config: GCPOperatorConfig, member: str, role: str): - """Add IAM policy binding to project.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "projects", - "add-iam-policy-binding", - config.project_id, - f"--member={member}", - f"--role={role}", - ] - run_command(cmd) - - -# KMS operations -def create_keyring(config: GCPAssetConfig): - """Create KMS keyring.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "kms", - "keyrings", - "create", - config.keyring_name, - "--location=global", - ] - try: - run_command(cmd) - except ExecutionError as e: - logging.debug(f"Keyring {config.keyring_name} may already exist. 
Error: {e}") - - -def create_kms_key(config: GCPAssetConfig): - """Create KMS key.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "kms", - "keys", - "create", - config.key_name, - "--location=global", - f"--keyring={config.keyring_name}", - "--purpose=encryption", - ] - try: - run_command(cmd) - except ExecutionError as e: - logging.debug( - f"Key {config.key_name} may already exist in keyring {config.keyring_name}. Error: {e}" + @property + def full_wip_name(self) -> str: + return ( + f"projects/{self.project_number}/locations/global/" + f"workloadIdentityPools/{self.wip}" ) -def add_kms_key_iam_policy_binding(config: GCPAssetConfig, member: str, role: str): - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "kms", - "keys", - "add-iam-policy-binding", - config.full_key_name, - f"--member={member}", - f"--role={role}", - ] - run_command(cmd) - - def set_kms_iam_policy(config: GCPAssetConfig, members: list[str], role: str): client = kms.KeyManagementServiceClient() # Get current policy @@ -249,7 +150,6 @@ def encrypt_with_kms_key( """Encrypt file using KMS key.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "kms", "encrypt", f"--ciphertext-file={ciphertext_file}", @@ -259,56 +159,11 @@ def encrypt_with_kms_key( run_command(cmd) -# Workload Identity Pool operations -def create_workload_identity_pool(config: GCPAssetConfig): - """Create workload identity pool.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "iam", - "workload-identity-pools", - "create", - config.wip, - "--location=global", - ] - try: - run_command(cmd) - except ExecutionError as e: - logging.debug( - f"Workload identity pool {config.wip} may already exist. 
Error: {e}" - ) - - -def create_workload_identity_pool_oidc_provider( +def update_workload_identity_pool_oidc_provider( config: GCPAssetConfig, attribute_mapping: str, attribute_condition: str ): - """Create OIDC provider for workload identity pool.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", - "iam", - "workload-identity-pools", - "providers", - "create-oidc", - "attestation-verifier", - "--location=global", - f"--workload-identity-pool={config.wip}", - "--issuer-uri=https://confidentialcomputing.googleapis.com/", - "--allowed-audiences=https://sts.googleapis.com", - f"--attribute-mapping={attribute_mapping}", - f"--attribute-condition={attribute_condition}", - ] - try: - run_command(cmd) - return - except ExecutionError as e: - logging.debug( - f"OIDC provider for workload identity pool {config.wip} may already exist. Error: {e}" - ) - # try updating the provider if it already exists - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", "iam", "workload-identity-pools", "providers", @@ -326,29 +181,12 @@ def create_workload_identity_pool_oidc_provider( ) -# Storage operations -def create_storage_bucket(config: Union[GCPAssetConfig, GCPOperatorConfig]): - """Create GCS bucket.""" - client = storage.Client(project=config.project_id) - for bucket in client.list_buckets(project=config.project_id): - if bucket.name == config.bucket: - logging.debug(f"Bucket {config.bucket} already exists.") - return - - # try creating the bucket - try: - client.create_bucket(config.bucket, project=config.project_id) - except Conflict as e: - logging.debug(f"Bucket {config.bucket} already exists. 
Conflict: {e}") - - def upload_file_to_gcs( config: Union[GCPAssetConfig, GCPOperatorConfig], local_file: str, gcs_path: str ): """Upload file to Google Cloud Storage.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "storage", "cp", local_file, @@ -363,7 +201,6 @@ def download_file_from_gcs( """Download file from Google Cloud Storage.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "storage", "cp", gcs_path, @@ -372,23 +209,6 @@ def download_file_from_gcs( run_command(cmd) -def add_bucket_iam_policy_binding( - config: Union[GCPAssetConfig, GCPOperatorConfig], member: str, role: str -): - """Add IAM policy binding to GCS bucket.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "storage", - "buckets", - "add-iam-policy-binding", - f"gs://{config.bucket}", - f"--member={member}", - f"--role={role}", - ] - run_command(cmd) - - # run def run_workload( config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: str @@ -396,7 +216,6 @@ def run_workload( # note: machine type and cc type must conform somehow cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instances", "create", @@ -404,7 +223,7 @@ def run_workload( f"--confidential-compute-type={config.cc_type}", "--shielded-secure-boot", "--scopes=cloud-platform", - f"--boot-disk-size={config.boot_disk_size}", + f"--boot-disk-size={config.boot_disk_size}G", f"--zone={config.vm_zone}", f"--network={config.vm_network}", "--maintenance-policy=MIGRATE", @@ -427,7 +246,6 @@ def run_gpu_workload( cmd = [ GCP_EXEC, - f"--project={config.project_id}", "beta", "compute", "instance-templates", @@ -458,7 +276,6 @@ def run_gpu_workload( cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instance-groups", "managed", @@ -473,7 +290,6 @@ def run_gpu_workload( cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instance-groups", "managed", diff --git a/cli/medperf/asset_management/operator_check.py b/cli/medperf/asset_management/operator_check.py new 
file mode 100644 index 000000000..dc30ab209 --- /dev/null +++ b/cli/medperf/asset_management/operator_check.py @@ -0,0 +1,59 @@ +import google.auth +from medperf.asset_management.checks_utils import ( + check_user_role_on_bucket, + check_sa_roles_for_project, + check_user_role_on_service_account, + impersonate_service_account, +) + + +def verify_operator_setup(sa_email, project_id, bucket_name): + base_creds, _ = google.auth.default() + + result = check_user_role_on_service_account( + base_creds, + sa_email, + "roles/iam.serviceAccountUser", + ) + if result: + return False, result + + result = check_user_role_on_bucket( + "user", + base_creds, + bucket_name, + "roles/storage.objectViewer", + ) + if result: + return False, result + + sa_creds = impersonate_service_account(base_creds, sa_email) + if not sa_creds: + return False, f"Failed to impersonate service account: {sa_email}" + + result = check_sa_roles_for_project( + sa_creds, + project_id, + "roles/confidentialcomputing.workloadUser", + ) + if result: + return False, result + + result = check_sa_roles_for_project( + sa_creds, + project_id, + "roles/logging.logWriter", + ) + if result: + return False, result + + result = check_user_role_on_bucket( + "service account", + sa_creds, + bucket_name, + "roles/storage.objectAdmin", + ) + if result: + return False, result + + return True, "" diff --git a/cli/requirements.txt b/cli/requirements.txt index 56531ad26..506f82877 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -28,8 +28,13 @@ fastapi==0.111.1 fastapi-login==1.10.2 cryptography==46.0.3 click==8.1.8 + +google-auth +google-cloud-storage +google-api-python-client +google-auth-httplib2 + google-cloud-iam google-cloud-kms -google-cloud-storage google-cloud-resource-manager google-cloud-compute \ No newline at end of file From c0c5085d9f0c9c367da31b8d3870dde5dbcb936d Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 00:26:17 +0100 Subject: [PATCH 08/72] add webui for cc --- 
cli/medperf/web_ui/datasets/routes.py | 40 ++++++++++++ cli/medperf/web_ui/models/routes.py | 38 +++++++++++ cli/medperf/web_ui/settings.py | 47 ++++++++++++++ cli/medperf/web_ui/static/js/cc.js | 65 +++++++++++++++++++ cli/medperf/web_ui/static/js/cc_operator.js | 65 +++++++++++++++++++ .../templates/dataset/dataset_detail.html | 6 ++ .../templates/macros/cc_asset_macro.html | 41 ++++++++++++ .../templates/macros/cc_operator_macro.html | 43 ++++++++++++ .../web_ui/templates/model/model_detail.html | 6 ++ cli/medperf/web_ui/templates/settings.html | 6 ++ 10 files changed, 357 insertions(+) create mode 100644 cli/medperf/web_ui/static/js/cc.js create mode 100644 cli/medperf/web_ui/static/js/cc_operator.js create mode 100644 cli/medperf/web_ui/templates/macros/cc_asset_macro.html create mode 100644 cli/medperf/web_ui/templates/macros/cc_operator_macro.html diff --git a/cli/medperf/web_ui/datasets/routes.py b/cli/medperf/web_ui/datasets/routes.py index 45696ee00..1cc190157 100644 --- a/cli/medperf/web_ui/datasets/routes.py +++ b/cli/medperf/web_ui/datasets/routes.py @@ -17,6 +17,7 @@ from medperf.commands.execution.create import BenchmarkExecution from medperf.commands.execution.submit import ResultSubmission from medperf.commands.execution.utils import filter_latest_executions +from medperf.commands.cc.dataset_configure_for_cc import DatasetConfigureForCC from medperf.entities.cube import Cube from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark @@ -133,6 +134,10 @@ def dataset_detail_ui( # noqa model.result["results"] = result.read_results() report_exists = os.path.exists(dataset.report_path) + + cc_config_defaults = dataset.get_cc_config() + cc_configured = dataset.is_cc_configured() + return templates.TemplateResponse( "dataset/dataset_detail.html", { @@ -147,6 +152,8 @@ def dataset_detail_ui( # noqa "approved_benchmarks": approved_benchmarks, "is_owner": is_owner, "report_exists": report_exists, + "cc_config_defaults": 
cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -475,3 +482,36 @@ def import_dataset( ) return return_response + + +@router.post("/edit_cc_config", response_class=JSONResponse) +def edit_cc_config( + entity_id: int = Form(...), + require_cc: bool = Form(...), + project_id: str = Form(""), + project_number: str = Form(""), + account: str = Form(""), + bucket: str = Form(""), + keyring_name: str = Form(""), + key_name: str = Form(""), + wip: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "project_number": project_number, + "account": account, + "bucket": bucket, + "keyring_name": keyring_name, + "key_name": key_name, + "wip": wip, + } + if not require_cc: + args = {} + + try: + DatasetConfigureForCC.run(entity_id, args, {}) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/models/routes.py b/cli/medperf/web_ui/models/routes.py index 74e1eb5c9..63be263d1 100644 --- a/cli/medperf/web_ui/models/routes.py +++ b/cli/medperf/web_ui/models/routes.py @@ -8,6 +8,7 @@ from medperf.entities.benchmark import Benchmark from medperf.commands.model.associate import AssociateModel from medperf.commands.mlcube.utils import check_access_to_container +from medperf.commands.cc.model_configure_for_cc import ModelConfigureForCC import medperf.config as config from medperf.web_ui.common import ( check_user_api, @@ -74,6 +75,8 @@ def model_detail_ui( else: container_object = model.container_obj + cc_config_defaults = model.get_cc_config() + cc_configured = model.is_cc_configured() return templates.TemplateResponse( "model/model_detail.html", { @@ -86,6 +89,8 @@ def model_detail_ui( "is_owner": is_owner, "benchmarks_associations": benchmark_associations, # "benchmarks": benchmarks, + "cc_config_defaults": cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -117,3 +122,36 @@ def 
associate( url=f"/models/ui/display/{model_id}", ) return return_response + + +@router.post("/edit_cc_config", response_class=JSONResponse) +def edit_cc_config( + entity_id: int = Form(...), + require_cc: bool = Form(...), + project_id: str = Form(""), + project_number: str = Form(""), + account: str = Form(""), + bucket: str = Form(""), + keyring_name: str = Form(""), + key_name: str = Form(""), + wip: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "project_number": project_number, + "account": account, + "bucket": bucket, + "keyring_name": keyring_name, + "key_name": key_name, + "wip": wip, + } + if not require_cc: + args = {} + + try: + ModelConfigureForCC.run(entity_id, args, {}) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/settings.py b/cli/medperf/web_ui/settings.py index d188833ed..67082d40a 100644 --- a/cli/medperf/web_ui/settings.py +++ b/cli/medperf/web_ui/settings.py @@ -6,8 +6,10 @@ from medperf.commands.certificate.delete_client_certificate import DeleteCertificate from medperf.commands.certificate.submit import SubmitCertificate from medperf.commands.certificate.utils import current_user_certificate_status +from medperf.account_management.account_management import get_medperf_user_object from medperf.commands.utils import set_profile_args from medperf.config_management.config_management import read_config, write_config +from medperf.commands.cc.setup_cc_operator import SetupCCOperator from medperf.entities.ca import CA from medperf.exceptions import InvalidArgumentError from medperf.utils import make_pretty_dict @@ -33,11 +35,18 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)): cas = None certificate_status = None + cc_config_defaults = {} + cc_configured = False + if is_logged_in(): cas = CA.all() cas = {c.id: c.name for c in 
cas} certificate_status = current_user_certificate_status() + user = get_medperf_user_object() + cc_config_defaults = user.get_cc_config() + cc_configured = user.is_cc_configured() + return templates.TemplateResponse( "settings.html", { @@ -49,6 +58,8 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)): "default_ca": config.certificate_authority_id, "default_fingerprint": config.certificate_authority_fingerprint, "certificate_status": certificate_status, + "cc_config_defaults": cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -206,3 +217,39 @@ def submit_certificate( return_response=return_response, ) return return_response + + +@router.post("/edit_cc_operator", response_class=JSONResponse) +def edit_cc_operator( + require_cc: bool = Form(...), + project_id: str = Form(""), + service_account_name: str = Form(""), + machine_type: str = Form(""), + account: str = Form(""), + bucket: str = Form(""), + vm_zone: str = Form(""), + vm_network: str = Form(""), + boot_disk_size: str = Form(""), + gpus: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "service_account_name": service_account_name, + "machine_type": machine_type, + "account": account, + "bucket": bucket, + "vm_zone": vm_zone, + "vm_network": vm_network, + "boot_disk_size": boot_disk_size, + "gpus": gpus, + } + if not require_cc: + args = {} + + try: + SetupCCOperator.run(args) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/static/js/cc.js b/cli/medperf/web_ui/static/js/cc.js new file mode 100644 index 000000000..0a343a03f --- /dev/null +++ b/cli/medperf/web_ui/static/js/cc.js @@ -0,0 +1,65 @@ +const fields = [ + "cc-project_id", + "cc-project_number", + "cc-account", + "cc-bucket", + "cc-keyring_name", + "cc-key_name", + "cc-wip", + "require-cc" +]; + +function 
checkForCCEditChanges() { + // const hasChanges = fields.some(field => { + // return $(`#${field}`).val() !== window.defaultCCConfig[field]; + // }); + // TODO + hasChanges = true; + $('#apply-cc-asset-btn').prop('disabled', !hasChanges); +} + +function editCCConfig(editCCConfigBtn) { + const formData = new FormData($("#edit-cc-asset-form")[0]); + const entityId = editCCConfigBtn.getAttribute("data-entity-id"); + const entityType = editCCConfigBtn.getAttribute("data-entity-type"); + formData.append("entity_id", entityId); + const url = `/${entityType}s/edit_cc_config`; + + disableElements("#edit-cc-asset-form input, #edit-cc-asset-form button"); + disableElements(".card button"); + + ajaxRequest( + url, + "POST", + formData, + (response) => { + if (response.status === "success"){ + showReloadModal({ + title: "CC Configuration Edited Successfully", + seconds: 3 + }); + } + else { + showErrorModal("Failed to Edit CC Configuration", response); + } + }, + "Error editing CC Configuration:" + ); +} + + +$(document).ready(() => { + const checkbox = $("#require-cc"); + checkbox.on("change", () => { + $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); + }); + $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); + + fields.forEach(field => $(`#${field}`).on('input', checkForCCEditChanges)); + checkForCCEditChanges(); + + $("#apply-cc-asset-btn").on("click", (e) => { + showConfirmModal(e.currentTarget, editCCConfig, "edit CC configuration?"); + }); + +}); diff --git a/cli/medperf/web_ui/static/js/cc_operator.js b/cli/medperf/web_ui/static/js/cc_operator.js new file mode 100644 index 000000000..27efa937e --- /dev/null +++ b/cli/medperf/web_ui/static/js/cc_operator.js @@ -0,0 +1,65 @@ +const fields = [ + "operator-project_id", + "operator-service_account_name", + "operator-account", + "operator-bucket", + "operator-vm_zone", + "operator-vm_network", + "operator-boot_disk_size", + "operator-gpus", + "operator-machine_type", + "require-cc-operator" +]; + +function 
checkForCCEditChanges() { + // const hasChanges = fields.some(field => { + // return $(`#${field}`).val() !== window.defaultCCConfig[field]; + // }); + // TODO + hasChanges = true; + $('#apply-cc-operator-btn').prop('disabled', !hasChanges); +} + +function editCCConfig(editCCConfigBtn) { + const formData = new FormData($("#edit-cc-operator-form")[0]); + + // TODO: properly disable elements + disableElements("#profiles-form select, #profiles-form button"); + disableElements("#edit-config-form input, #edit-config-form button, #edit-config-form select"); + disableElements("#certificate-settings button"); + + ajaxRequest( + "/settings/edit_cc_operator", + "POST", + formData, + (response) => { + if (response.status === "success"){ + showReloadModal({ + title: "CC Configuration Edited Successfully", + seconds: 3 + }); + } + else { + showErrorModal("Failed to Edit CC Configuration", response); + } + }, + "Error editing CC Configuration:" + ); +} + + +$(document).ready(() => { + const checkbox = $("#require-cc-operator"); + checkbox.on("change", () => { + $("#edit-cc-operator-fields").toggle(checkbox.is(":checked")); + }); + $("#edit-cc-operator-fields").toggle(checkbox.is(":checked")); + + fields.forEach(field => $(`#${field}`).on('input', checkForCCEditChanges)); + checkForCCEditChanges(); + + $("#apply-cc-operator-btn").on("click", (e) => { + showConfirmModal(e.currentTarget, editCCConfig, "edit CC configuration?"); + }); + +}); diff --git a/cli/medperf/web_ui/templates/dataset/dataset_detail.html b/cli/medperf/web_ui/templates/dataset/dataset_detail.html index 78936505e..a0b423fba 100644 --- a/cli/medperf/web_ui/templates/dataset/dataset_detail.html +++ b/cli/medperf/web_ui/templates/dataset/dataset_detail.html @@ -5,6 +5,7 @@ {% import 'macros/container_macros.html' as container_macros %} {% import 'macros/model_macros.html' as model_macros %} {% import 'macros/benchmark_macros.html' as benchmark_macros %} +{% import 'macros/cc_asset_macro.html' as cc_asset_macro %} 
{% block title %}Dataset Details{% endblock title %} @@ -410,6 +411,10 @@

{% endfor %} {% endif %} +{% if is_owner %} + {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, dataset.id, "dataset", task_running) }} +{% endif %} + {% include "partials/panel_container.html" %} {% include "partials/text_content_container.html" %} {% include "partials/yaml_container.html" %} @@ -421,6 +426,7 @@

+ {% if task_running and task_formData.get("dataset_id", "") == dataset.id|string %} {% if request.app.state.task.name == "dataset_preparation" %} diff --git a/cli/medperf/web_ui/templates/macros/cc_asset_macro.html b/cli/medperf/web_ui/templates/macros/cc_asset_macro.html new file mode 100644 index 000000000..5c8f95816 --- /dev/null +++ b/cli/medperf/web_ui/templates/macros/cc_asset_macro.html @@ -0,0 +1,41 @@ +{# ./cli/medperf/web_ui/templates/macros/gcp_asset_macro.html #} + +{% macro gcp_asset(defaults, configured, entity_id, entity_type, task_running=false) %} +{% set fields = [ +("project_id", "GCP Project ID"), +("project_number", "GCP Project Number"), +("account", "GCP Account Email"), +("bucket", "GCP Bucket Name"), +("keyring_name", "GCP KMS Keyring Name"), +("key_name", "GCP KMS Key Name"), +("wip", "GCP Workload Identity Pool Name"), +] %} +
+
+
+
+
+ + +
+
+
+
+ {% for field_name, field_label in fields %} +
+ +
+ +
+
+ {% endfor %} +
+
+ +
+
+
+{% endmacro %} \ No newline at end of file diff --git a/cli/medperf/web_ui/templates/macros/cc_operator_macro.html b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html new file mode 100644 index 000000000..2be15e843 --- /dev/null +++ b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html @@ -0,0 +1,43 @@ +{# ./cli/medperf/web_ui/templates/macros/cc_operator_macro.html #} + +{% macro cc_operator(defaults, configured, task_running=false) %} +{% set fields = [ +("project_id", "GCP Project ID"), +("service_account_name", "GCP Service Account Name"), +("account", "GCP Account Email"), +("bucket", "GCP Bucket Name"), +("machine_type", "GCP Machine Type"), +("vm_zone", "VM Zone"), +("vm_network", "VM Network"), +("boot_disk_size", "VM Boot Disk Size (GB)"), +("gpus", "VM Number of GPUs"), +] %} +
+
+
+
+
+ + +
+
+
+
+ {% for field_name, field_label in fields %} +
+ +
+ +
+
+ {% endfor %} +
+
+ +
+
+
+{% endmacro %} \ No newline at end of file diff --git a/cli/medperf/web_ui/templates/model/model_detail.html b/cli/medperf/web_ui/templates/model/model_detail.html index 83046038a..6c4a65e63 100644 --- a/cli/medperf/web_ui/templates/model/model_detail.html +++ b/cli/medperf/web_ui/templates/model/model_detail.html @@ -4,6 +4,7 @@ {% import 'macros/association_card_macros.html' as association_card_macros %} {% import 'macros/model_detail_util_macro.html' as model_detail_util_macro %} +{% import 'macros/cc_asset_macro.html' as cc_asset_macro %} {% block title %}Model Details{% endblock %} @@ -157,6 +158,10 @@

Associated Benchmarks

{% endif %} +{% if is_owner %} + {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, entity.id, "model", task_running) }} +{% endif %} + {% include "partials/panel_container.html" %} {% include "partials/text_content_container.html" %} {% include "partials/yaml_container.html" %} @@ -166,6 +171,7 @@

Associated Benchmarks

{% block extra_js %} + {% if task_running and request.app.state.task.name == "model_association" and task_formData.get("model_id", "") == entity.id|string %} + {% endif %} {% endif %} + +{% if task_running and request.app.state.task.name == "data_update_cc_config" and task_formData.get("dataset_id", "") == entity.id|string %} + +{% endif %} + {% endblock extra_js %} diff --git a/cli/medperf/web_ui/templates/dataset/register_dataset.html b/cli/medperf/web_ui/templates/dataset/register_dataset.html index 2c07f045f..5078b5d58 100644 --- a/cli/medperf/web_ui/templates/dataset/register_dataset.html +++ b/cli/medperf/web_ui/templates/dataset/register_dataset.html @@ -151,6 +151,14 @@

Register a New Dataset

> +
+
+
+ + +
+
+
+ {% endif %} {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, entity.id, "model", task_running) }} {% endif %} @@ -183,4 +193,16 @@

Associated Benchmarks

}); {% endif %} + +{% if task_running and request.app.state.task.name == "model_update_cc_config" and task_formData.get("model_id", "") == entity.id|string %} + +{% endif %} {% endblock %} \ No newline at end of file From 72805c30bf53c90c31213a0a9b16d081df89f350 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 21:41:21 +0100 Subject: [PATCH 22/72] add file exists api for cc --- cli/medperf/asset_management/asset_management.py | 5 +++++ cli/medperf/asset_management/cc_operator.py | 10 ++++++++++ cli/medperf/asset_management/gcp_utils.py | 13 ++++++++++++- 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py index 7b333fedc..3312de2a4 100644 --- a/cli/medperf/asset_management/asset_management.py +++ b/cli/medperf/asset_management/asset_management.py @@ -186,3 +186,8 @@ def download_results( operator_manager = OperatorManager(operator_cc_config) operator_manager.download_results(workload, private_key_bytes, results_path) + + +def workload_results_exists(operator_cc_config: dict, workload: CCWorkloadID) -> bool: + operator_manager = OperatorManager(operator_cc_config) + return operator_manager.results_exist(workload) diff --git a/cli/medperf/asset_management/cc_operator.py b/cli/medperf/asset_management/cc_operator.py index 4a740353a..42e2cef20 100644 --- a/cli/medperf/asset_management/cc_operator.py +++ b/cli/medperf/asset_management/cc_operator.py @@ -3,6 +3,7 @@ GCPOperatorConfig, CCWorkloadID, download_file_from_gcs, + check_gcs_file_exists, run_workload, run_gpu_workload, wait_for_workload_completion, @@ -79,6 +80,15 @@ def run_workload( def wait_for_workload_completion(self, workload: CCWorkloadID): wait_for_workload_completion(self.config, workload) + def results_exist(self, workload: CCWorkloadID): + results_exist = check_gcs_file_exists(self.config, workload.results_path) + if not results_exist: + return False + decryption_key_exists = 
check_gcs_file_exists( + self.config, workload.results_encryption_key_path + ) + return decryption_key_exists + def download_results( self, workload: CCWorkloadID, diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index d99638146..8068de432 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -5,7 +5,7 @@ from medperf.utils import run_command from google.cloud import kms from google.iam.v1 import policy_pb2 -from google.cloud import compute_v1 +from google.cloud import compute_v1, storage import time from colorama import Fore, Style import medperf.config as medperf_config @@ -23,6 +23,7 @@ class CCWorkloadID(BaseModel): data_id: int model_id: int script_id: int + execution_id: int = None @property def id(self): @@ -46,6 +47,8 @@ def id_for_model(self): @property def human_readable_id(self): + if self.execution_id: + return f"d{self.data_id}-m{self.model_id}-s{self.script_id}-e{self.execution_id}" return f"d{self.data_id}-m{self.model_id}-s{self.script_id}" @property @@ -217,6 +220,14 @@ def download_file_from_gcs( run_command(cmd) +def check_gcs_file_exists( + config: Union[GCPAssetConfig, GCPOperatorConfig], gcs_path: str +) -> bool: + client = storage.Client() + bucket = client.bucket(config.bucket) + return bucket.blob(gcs_path).exists() + + # run def run_workload( config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: str From 92e2433c464d7980010a2a2643bde9f74cae52ee Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 21:41:34 +0100 Subject: [PATCH 23/72] add is_script api for containers --- cli/medperf/containers/parsers/parser.py | 4 ++++ cli/medperf/containers/parsers/simple_container.py | 3 +++ cli/medperf/entities/cube.py | 3 +++ 3 files changed, 10 insertions(+) diff --git a/cli/medperf/containers/parsers/parser.py b/cli/medperf/containers/parsers/parser.py index 87bc8d097..98c5c1af2 100644 --- 
a/cli/medperf/containers/parsers/parser.py +++ b/cli/medperf/containers/parsers/parser.py @@ -49,3 +49,7 @@ def is_docker_image(self): @abstractmethod def is_model_container(self): pass + + @abstractmethod + def is_script_container(self): + pass diff --git a/cli/medperf/containers/parsers/simple_container.py b/cli/medperf/containers/parsers/simple_container.py index 02ec653f6..1404833b7 100644 --- a/cli/medperf/containers/parsers/simple_container.py +++ b/cli/medperf/containers/parsers/simple_container.py @@ -118,3 +118,6 @@ def is_docker_image(self): def is_model_container(self): return "infer" in self.container_config["tasks"] + + def is_script_container(self): + return "run_script" in self.container_config["tasks"] diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index c776c0fa1..b66183192 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -101,6 +101,9 @@ def is_encrypted(self) -> bool: def is_model(self) -> bool: return self.parser.is_model_container() + def is_script(self) -> bool: + return self.parser.is_script_container() + @staticmethod def remote_prefilter(filters: dict): """Applies filtering logic that must be done before retrieving remote entities From 5d6521c147ba8e12d78adcfbddce4d95768af8e5 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 22:54:02 +0100 Subject: [PATCH 24/72] add script execution without metrics --- .../confidential_model_container_execution.py | 186 ++++++++++++++++++ .../commands/execution/container_execution.py | 24 +-- .../commands/execution/execution_flow.py | 24 ++- 3 files changed, 221 insertions(+), 13 deletions(-) create mode 100644 cli/medperf/commands/execution/confidential_model_container_execution.py diff --git a/cli/medperf/commands/execution/confidential_model_container_execution.py b/cli/medperf/commands/execution/confidential_model_container_execution.py new file mode 100644 index 000000000..e664c9561 --- /dev/null +++ 
b/cli/medperf/commands/execution/confidential_model_container_execution.py @@ -0,0 +1,186 @@ +import base64 + +from medperf.asset_management.gcp_utils import CCWorkloadID +from medperf.entities.cube import Cube +from medperf.entities.model import Model +from medperf.entities.dataset import Dataset +from medperf.entities.execution import Execution +from medperf.entities.certificate import Certificate +import medperf.config as config +from medperf.exceptions import DecryptionError, ExecutionError + +from medperf.account_management import get_medperf_user_object +from medperf.asset_management.asset_management import ( + run_workload, + download_results, + workload_results_exists, +) +from medperf.utils import get_string_hash +from medperf.commands.certificate.utils import load_user_private_key +from medperf.commands.execution.container_execution import ContainerExecution + + +class ConfidentialModelContainerExecution: + @classmethod + def run( + cls, + benchmark_id: int, + dataset: Dataset, + model: Model, + script: Cube, + evaluator: Cube, + execution: Execution = None, + ignore_model_errors=False, + ): + """Benchmark execution flow. 
+ + Args: + benchmark_id (int): UID of the desired benchmark + dataset (Dataset): Registered dataset to run on + model (Model): Model to execute + """ + execution_flow = cls( + benchmark_id, + dataset, + model, + script, + evaluator, + execution, + ignore_model_errors, + ) + execution_flow.setup_local_environment() + with config.ui.interactive(): + execution_flow.get_operator() + execution_flow.validate() + execution_flow.prepare() + execution_flow.setup_workload() + if execution_flow.should_run_workload(): + execution_flow.run_workload() + execution_flow.download_predictions() + execution_flow.run_evaluation() + execution_summary = execution_flow.todict() + return execution_summary + + def __init__( + self, + benchmark_id: int, + dataset: Dataset, + model: Model, + script: Cube, + evaluator: Cube, + execution: Execution = None, + ignore_model_errors=False, + ): + self.comms = config.comms + self.ui = config.ui + self.benchmark_id = benchmark_id + self.dataset = dataset + self.model = model + self.script = script + self.evaluator = evaluator + self.execution = execution + self.ignore_model_errors = ignore_model_errors + self.operator = None + self.dataset_cc_config = None + self.model_cc_config = None + self.operator_cc_config = None + self.local_execution_flow = None + + def setup_local_environment(self): + self.local_execution_flow = ContainerExecution( + self.dataset, + self.model, + self.evaluator, + self.execution, + self.ignore_model_errors, + ) + self.local_execution_flow.prepare() + + def get_operator(self): + self.operator = get_medperf_user_object() + + def validate(self): + if not self.dataset.is_cc_configured(): + raise ExecutionError( + f"Dataset {self.dataset.id} is not configured for confidential computing." + ) + if not self.model.is_cc_configured(): + raise ExecutionError( + f"Model {self.model.id} is not configured for confidential computing."
+ ) + if not self.operator.is_cc_configured(): + raise ExecutionError( + "User does not have a configuration to operate a confidential execution." + ) + + def prepare(self): + self.dataset_cc_config = self.dataset.get_cc_config() + self.model_cc_config = self.model.get_cc_config() + self.operator_cc_config = self.operator.get_cc_config() + self.asset = self.model.asset_obj + + def setup_workload(self): + if self.dataset.owner == self.operator.id: + cert_obj = Certificate.get_user_certificate() + else: + datasets_certs = config.comms.get_benchmark_datasets_certificates( + self.benchmark_id + ) + for cert in datasets_certs: + if cert["owner"]["id"] == self.dataset.owner: + cert.pop("owner") + cert_obj = Certificate(**cert) + break + else: + raise ExecutionError("Dataset not associated.") + + public_key_bytes = cert_obj.public_key() + result_collector_public_key = base64.b64encode(public_key_bytes) + workload = CCWorkloadID( + data_hash=self.dataset.generated_uid, + model_hash=self.asset.asset_hash, + script_hash=self.script.image_hash, + result_collector_hash=get_string_hash(result_collector_public_key), + data_id=self.dataset.id, + model_id=self.asset.id, + script_id=self.script.id, + execution_id=self.execution.id, + ) + + self.workload = workload + self.result_collector_public_key = result_collector_public_key + + def should_run_workload(self): + return not workload_results_exists(self.operator_cc_config, self.workload) + + def run_workload(self): + config.ui.text = "Running CC workload..." + docker_image = self.script.parser.get_setup_args() + # TODO: docker.io/ + docker_image = "docker.io/" + docker_image + run_workload( + docker_image, + self.workload, + self.dataset_cc_config, + self.model_cc_config, + self.operator_cc_config, + self.result_collector_public_key.decode("utf-8"), + ) + + def download_predictions(self): + config.ui.text = "Downloading results..." 
+ results_path = self.local_execution_flow.preds_path + private_key_bytes = load_user_private_key() + if private_key_bytes is None: + raise DecryptionError("Missing Private Key") + + # TODO: results_path may contain root name + download_results( + self.operator_cc_config, self.workload, private_key_bytes, results_path + ) + + def run_evaluation(self): + return self.local_execution_flow.run_evaluation() + + def todict(self): + return self.local_execution_flow.todict() diff --git a/cli/medperf/commands/execution/container_execution.py b/cli/medperf/commands/execution/container_execution.py index 8088a2ac7..6e27f1c4a 100644 --- a/cli/medperf/commands/execution/container_execution.py +++ b/cli/medperf/commands/execution/container_execution.py @@ -102,8 +102,8 @@ def __setup_local_outputs_path(self): return local_outputs_path def set_pending_status(self): - self.__send_model_report("pending") - self.__send_evaluator_report("pending") + self.send_model_report("pending") + self.send_evaluator_report("pending") def run_inference(self): self.ui.text = f"Running inference of model '{self.model.name}' on dataset" @@ -112,7 +112,7 @@ def run_inference(self): "data_path": self.dataset.data_path, "output_path": self.preds_path, } - self.__send_model_report("started") + self.send_model_report("started") try: self.model.run( task="infer", @@ -123,7 +123,7 @@ def run_inference(self): self.ui.print("> Model execution complete") except ExecutionError as e: - self.__send_model_report("failed") + self.send_model_report("failed") if not self.ignore_model_errors: logging.error(f"Model Execution failed: {e}") raise ExecutionError(f"Model Execution failed: {e}") @@ -133,9 +133,9 @@ def run_inference(self): return except KeyboardInterrupt: logging.warning("Model Execution interrupted by user") - self.__send_model_report("interrupted") + self.send_model_report("interrupted") raise CleanExit("Model Execution interrupted by user") - self.__send_model_report("finished") + 
self.send_model_report("finished") def run_evaluation(self): self.ui.text = f"Calculating metrics for model '{self.model.name}' predictions" @@ -146,7 +146,7 @@ def run_evaluation(self): "output_path": self.results_path, "local_outputs_path": self.local_outputs_path, } - self.__send_evaluator_report("started") + self.send_evaluator_report("started") try: self.evaluator.run( task="evaluate", @@ -156,13 +156,13 @@ def run_evaluation(self): ) except ExecutionError as e: logging.error(f"Metrics calculation failed: {e}") - self.__send_evaluator_report("failed") + self.send_evaluator_report("failed") raise ExecutionError(f"Metrics calculation failed: {e}") except KeyboardInterrupt: logging.warning("Metrics calculation interrupted by user") - self.__send_evaluator_report("interrupted") + self.send_evaluator_report("interrupted") raise CleanExit("Metrics calculation interrupted by user") - self.__send_evaluator_report("finished") + self.send_evaluator_report("finished") def todict(self): return { @@ -179,10 +179,10 @@ def get_results(self): raise ExecutionError("Results file is empty") return results - def __send_model_report(self, status: str): + def send_model_report(self, status: str): self.__send_report("model_report", status) - def __send_evaluator_report(self, status: str): + def send_evaluator_report(self, status: str): self.__send_report("evaluation_report", status) def __send_report(self, field: str, status: str): diff --git a/cli/medperf/commands/execution/execution_flow.py b/cli/medperf/commands/execution/execution_flow.py index 867b746f3..730e6a411 100644 --- a/cli/medperf/commands/execution/execution_flow.py +++ b/cli/medperf/commands/execution/execution_flow.py @@ -2,10 +2,14 @@ from medperf.entities.model import Model from medperf.entities.dataset import Dataset from medperf.entities.execution import Execution +from medperf.entities.benchmark import Benchmark from medperf.enums import ModelType from medperf.commands.execution.container_execution import 
ContainerExecution from medperf.commands.execution.script_execution import ScriptExecution from medperf.commands.execution.confidential_execution import ConfidentialExecution +from medperf.commands.execution.confidential_model_container_execution import ( + ConfidentialModelContainerExecution, +) from medperf.account_management import get_medperf_user_data, is_user_logged_in @@ -25,13 +29,31 @@ def run( ) if ( - model.type == ModelType.ASSET.value + evaluator.is_script() + and model.type == ModelType.ASSET.value and model.requires_cc() and not user_is_model_owner ): return ConfidentialExecution.run( benchmark_id, dataset, model, evaluator, execution, ignore_model_errors ) + elif ( + model.type == ModelType.ASSET.value + and model.requires_cc() + and not user_is_model_owner + ): + benchmark = Benchmark.get(benchmark_id) + model = Model.get(benchmark.reference_model) + script = model.container_obj + return ConfidentialModelContainerExecution.run( + benchmark_id, + dataset, + model, + script, + evaluator, + execution, + ignore_model_errors, + ) elif model.type == ModelType.ASSET.value: asset = model.asset_obj asset.prepare_asset_files() From 3c1c90cf47b895f0a5a8ee1c34121cf9da7fb1c5 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 11 Mar 2026 23:05:56 +0100 Subject: [PATCH 25/72] fix docker.io prefix problem --- .../commands/execution/confidential_execution.py | 4 ++-- .../confidential_model_container_execution.py | 4 ++-- cli/medperf/containers/runners/docker_utils.py | 13 +++++++++++++ cli/requirements.txt | 1 + 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/cli/medperf/commands/execution/confidential_execution.py b/cli/medperf/commands/execution/confidential_execution.py index 08d284dcd..44a66ce5a 100644 --- a/cli/medperf/commands/execution/confidential_execution.py +++ b/cli/medperf/commands/execution/confidential_execution.py @@ -18,6 +18,7 @@ from medperf.asset_management.asset_management import run_workload, download_results from medperf.utils 
import get_string_hash from medperf.commands.certificate.utils import load_user_private_key +from medperf.containers.runners.docker_utils import full_docker_image_name class ConfidentialExecution: @@ -132,8 +133,7 @@ def setup_workload(self): def run_workload(self): config.ui.text = "Running CC workload..." docker_image = self.script.parser.get_setup_args() - # TODO: docker.io/ - docker_image = "docker.io/" + docker_image + docker_image = full_docker_image_name(docker_image) run_workload( docker_image, self.workload, diff --git a/cli/medperf/commands/execution/confidential_model_container_execution.py b/cli/medperf/commands/execution/confidential_model_container_execution.py index e664c9561..1b33f1708 100644 --- a/cli/medperf/commands/execution/confidential_model_container_execution.py +++ b/cli/medperf/commands/execution/confidential_model_container_execution.py @@ -18,6 +18,7 @@ from medperf.utils import get_string_hash from medperf.commands.certificate.utils import load_user_private_key from medperf.commands.execution.container_execution import ContainerExecution +from medperf.containers.runners.docker_utils import full_docker_image_name class ConfidentialModelContainerExecution: @@ -156,8 +157,7 @@ def should_run_workload(self): def run_workload(self): config.ui.text = "Running CC workload..." 
docker_image = self.script.parser.get_setup_args() - # TODO: docker.io/ - docker_image = "docker.io/" + docker_image + docker_image = full_docker_image_name(docker_image) run_workload( docker_image, self.workload, diff --git a/cli/medperf/containers/runners/docker_utils.py b/cli/medperf/containers/runners/docker_utils.py index 38367e217..856e5cf06 100644 --- a/cli/medperf/containers/runners/docker_utils.py +++ b/cli/medperf/containers/runners/docker_utils.py @@ -8,6 +8,7 @@ import json import tarfile import logging +from docker.auth import resolve_repository_name def get_docker_image_hash(docker_image, timeout: int = None): @@ -177,3 +178,15 @@ def delete_images(images): run_command(delete_image_cmd) except ExecutionError: config.ui.print_warning("WARNING: Failed to delete docker images.") + + +def full_docker_image_name(image_name: str) -> str: + """ + Returns the full docker image name with registry. + If the image name does not contain a registry, it is assumed to be docker.io. + """ + logging.debug(f"Resolving full docker image name for {image_name}") + registry, name = resolve_repository_name(image_name) + resolved_name = f"{registry}/{name}" + logging.debug(f"Resolved docker image name: {resolved_name}") + return resolved_name diff --git a/cli/requirements.txt b/cli/requirements.txt index 506f82877..9fb797b62 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -28,6 +28,7 @@ fastapi==0.111.1 fastapi-login==1.10.2 cryptography==46.0.3 click==8.1.8 +docker==7.1.0 google-auth google-cloud-storage From 4fb74ddea9ff6992e33c8215a5e7fefe3d272527 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 00:55:25 +0000 Subject: [PATCH 26/72] update how tar works in cc base image --- examples/cc/base_image/src/store_results.py | 2 +- examples/cc/base_image/src/utils.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/cc/base_image/src/store_results.py b/examples/cc/base_image/src/store_results.py index 
ed4014988..bdb857fe6 100644 --- a/examples/cc/base_image/src/store_results.py +++ b/examples/cc/base_image/src/store_results.py @@ -20,7 +20,7 @@ def store_results(args) -> None: os.makedirs(tmp_files, exist_ok=True) tmp_result_archive = os.path.join(tmp_files, "result.tar.gz") - tar(folders_paths=[result_files_path], output_path=tmp_result_archive) + tar(output_path=tmp_result_archive, folder_path=result_files_path) # encrypt file encryption_key_file = os.path.join(tmp_files, "tmp_encryption_key") diff --git a/examples/cc/base_image/src/utils.py b/examples/cc/base_image/src/utils.py index c17491095..8bcad4485 100644 --- a/examples/cc/base_image/src/utils.py +++ b/examples/cc/base_image/src/utils.py @@ -23,11 +23,9 @@ def untar(filepath: str, extract_to: str) -> None: os.remove(filepath) -def tar(output_path: str, folders_paths: list[str]) -> None: +def tar(output_path: str, folder_path: str) -> None: logging.info(f"Compressing tar.gz at {output_path}") tar_arc = tarfile.open(output_path, "w:gz") - for folder in folders_paths: - arcname = os.path.basename(folder) - tar_arc.add(folder, arcname=arcname) - logging.info(f"Compressing tar.gz at {output_path}: {folder} Added.") + tar_arc.add(folder_path, arcname="") + logging.info(f"Compressing tar.gz at {output_path}: {folder_path} Added.") tar_arc.close() From 7a0f62c3230c2bb1361d40ce159e4562d0763651 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 00:57:05 +0000 Subject: [PATCH 27/72] expect wip_provider in gcp config in cc base image --- examples/cc/base_image/src/assets/gcp/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cc/base_image/src/assets/gcp/utils.py b/examples/cc/base_image/src/assets/gcp/utils.py index 759c8243d..1ca05522d 100644 --- a/examples/cc/base_image/src/assets/gcp/utils.py +++ b/examples/cc/base_image/src/assets/gcp/utils.py @@ -31,6 +31,7 @@ class GCPAssetConfig: key_name: str key_location: str wip: str + wip_provider: str @property def 
full_key_name(self) -> str: @@ -38,7 +39,7 @@ def full_key_name(self) -> str: @property def full_wip_name(self) -> str: - return f"projects/{self.project_number}/locations/global/workloadIdentityPools/{self.wip}/providers/attestation-verifier" + return f"projects/{self.project_number}/locations/global/workloadIdentityPools/{self.wip}/providers/{self.wip_provider}" @dataclass From cedf40b57b8ee529e242ba2d282f55d3c814dadb Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 01:03:52 +0000 Subject: [PATCH 28/72] make wip_provider required in client --- cli/medperf/asset_management/__gcp_util_archive.py | 2 +- cli/medperf/asset_management/gcp_utils.py | 5 +++-- examples/cc/chestxray/dataset_cc_config.json | 3 ++- examples/cc/chestxray/model_cc_config.json | 3 ++- examples/cc/rano/dataset_cc_config.json | 3 ++- examples/cc/rano/model_cc_config.json | 3 ++- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cli/medperf/asset_management/__gcp_util_archive.py b/cli/medperf/asset_management/__gcp_util_archive.py index 10805fa59..cdb04cfe1 100644 --- a/cli/medperf/asset_management/__gcp_util_archive.py +++ b/cli/medperf/asset_management/__gcp_util_archive.py @@ -137,7 +137,7 @@ def create_workload_identity_pool_oidc_provider( "workload-identity-pools", "providers", "create-oidc", - "attestation-verifier", + config.wip_provider, "--location=global", f"--workload-identity-pool={config.wip}", "--issuer-uri=https://confidentialcomputing.googleapis.com/", diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index 8068de432..5271e3ca4 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -113,6 +113,7 @@ class GCPAssetConfig(BaseModel): key_name: str key_location: str wip: str + wip_provider: str @property def full_key_name(self) -> str: @@ -125,7 +126,7 @@ def full_key_name(self) -> str: def full_wip_provider_name(self) -> str: return ( 
f"projects/{self.project_number}/locations/global/" - f"workloadIdentityPools/{self.wip}/providers/attestation-verifier" + f"workloadIdentityPools/{self.wip}/providers/{self.wip_provider}" ) @property @@ -179,7 +180,7 @@ def update_workload_identity_pool_oidc_provider( "workload-identity-pools", "providers", "update-oidc", - "attestation-verifier", + config.wip_provider, "--location=global", f"--workload-identity-pool={config.wip}", f"--attribute-mapping={attribute_mapping}", diff --git a/examples/cc/chestxray/dataset_cc_config.json b/examples/cc/chestxray/dataset_cc_config.json index 044aa2d56..27562278f 100644 --- a/examples/cc/chestxray/dataset_cc_config.json +++ b/examples/cc/chestxray/dataset_cc_config.json @@ -5,5 +5,6 @@ "keyring_name": "data-owner-keyring", "key_name": "data-owner-key2", "key_location": "us-west1", - "wip": "data-owner-wip" + "wip": "data-owner-wip", + "wip_provider": "attestation-verifier" } \ No newline at end of file diff --git a/examples/cc/chestxray/model_cc_config.json b/examples/cc/chestxray/model_cc_config.json index 6fcc6fff3..f932529ff 100644 --- a/examples/cc/chestxray/model_cc_config.json +++ b/examples/cc/chestxray/model_cc_config.json @@ -5,5 +5,6 @@ "keyring_name": "model-owner-keyring", "key_name": "model-owner-key", "key_location": "us-west1", - "wip": "model-owner-wip" + "wip": "model-owner-wip", + "wip_provider": "attestation-verifier" } \ No newline at end of file diff --git a/examples/cc/rano/dataset_cc_config.json b/examples/cc/rano/dataset_cc_config.json index 044aa2d56..27562278f 100644 --- a/examples/cc/rano/dataset_cc_config.json +++ b/examples/cc/rano/dataset_cc_config.json @@ -5,5 +5,6 @@ "keyring_name": "data-owner-keyring", "key_name": "data-owner-key2", "key_location": "us-west1", - "wip": "data-owner-wip" + "wip": "data-owner-wip", + "wip_provider": "attestation-verifier" } \ No newline at end of file diff --git a/examples/cc/rano/model_cc_config.json b/examples/cc/rano/model_cc_config.json index 
6fcc6fff3..f932529ff 100644 --- a/examples/cc/rano/model_cc_config.json +++ b/examples/cc/rano/model_cc_config.json @@ -5,5 +5,6 @@ "keyring_name": "model-owner-keyring", "key_name": "model-owner-key", "key_location": "us-west1", - "wip": "model-owner-wip" + "wip": "model-owner-wip", + "wip_provider": "attestation-verifier" } \ No newline at end of file From 1eb5492c708497bb8aa799f96ad7521922586b5c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 01:43:31 +0000 Subject: [PATCH 29/72] prevent an unsupported option for execution --- cli/medperf/commands/execution/execution_flow.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cli/medperf/commands/execution/execution_flow.py b/cli/medperf/commands/execution/execution_flow.py index 730e6a411..997445afe 100644 --- a/cli/medperf/commands/execution/execution_flow.py +++ b/cli/medperf/commands/execution/execution_flow.py @@ -11,6 +11,7 @@ ConfidentialModelContainerExecution, ) from medperf.account_management import get_medperf_user_data, is_user_logged_in +from medperf.exceptions import ExecutionError class ExecutionFlow: @@ -55,6 +56,10 @@ def run( ignore_model_errors, ) elif model.type == ModelType.ASSET.value: + if not evaluator.is_script(): + raise ExecutionError( + "Running a model container with another asset model is not supported yet." 
+ ) asset = model.asset_obj asset.prepare_asset_files() return ScriptExecution.run( From 07f17535aad3e90ccc35f8ae2904e0a1a3949c2a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 01:43:47 +0000 Subject: [PATCH 30/72] add model-only rano container --- .../rano/implementation/Dockerfile.modelonly | 17 +++++ .../benchmark/entrypoint_modelonly.sh | 71 +++++++++++++++++++ .../cc/rano/implementation/build_modelonly.sh | 2 + .../container_config_modelonly.yaml | 15 ++++ 4 files changed, 105 insertions(+) create mode 100644 examples/cc/rano/implementation/Dockerfile.modelonly create mode 100644 examples/cc/rano/implementation/benchmark/entrypoint_modelonly.sh create mode 100644 examples/cc/rano/implementation/build_modelonly.sh create mode 100644 examples/cc/rano/implementation/container_config_modelonly.yaml diff --git a/examples/cc/rano/implementation/Dockerfile.modelonly b/examples/cc/rano/implementation/Dockerfile.modelonly new file mode 100644 index 000000000..09cf10ce4 --- /dev/null +++ b/examples/cc/rano/implementation/Dockerfile.modelonly @@ -0,0 +1,17 @@ +FROM mlcommons/medperf-confidential-benchmark-base:0.0.0 + +ENV CUDA_VISIBLE_DEVICES="0" + + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 + +COPY ./requirements.txt /project/requirements.txt + +RUN pip install --no-cache-dir -r /project/requirements.txt + +# Copy project folder +COPY ./benchmark /project/benchmark +RUN rm -rf /project/benchmark/metrics && mv /project/benchmark/entrypoint_modelonly.sh /project/benchmark/entrypoint.sh diff --git a/examples/cc/rano/implementation/benchmark/entrypoint_modelonly.sh b/examples/cc/rano/implementation/benchmark/entrypoint_modelonly.sh new file mode 100644 index 000000000..f199c52f2 --- /dev/null +++ 
b/examples/cc/rano/implementation/benchmark/entrypoint_modelonly.sh @@ -0,0 +1,71 @@ +#!/bin/bash +set -eo pipefail + +# Default values +INPUT_DATA="" +INPUT_LABELS="" +MODEL_FILES="" +OUTPUT_RESULTS="" + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --input-data) + INPUT_DATA="$2" + shift 2 + ;; + --input-labels) + INPUT_LABELS="$2" + shift 2 + ;; + --model-files) + MODEL_FILES="$2" + shift 2 + ;; + --output-results) + OUTPUT_RESULTS="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Validate required arguments +if [[ -z "$INPUT_DATA" ]]; then + echo "Error: --input-data is required" + exit 1 +fi + +if [[ -z "$INPUT_LABELS" ]]; then + echo "Error: --input-labels is required" + exit 1 +fi + +if [[ -z "$MODEL_FILES" ]]; then + echo "Error: --model-files is required" + exit 1 +fi + +if [[ -z "$OUTPUT_RESULTS" ]]; then + echo "Error: --output-results is required" + exit 1 +fi + +# run benchmark +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +export USER=appuser + +start_time=$(date +%s) + +bash $SCRIPT_DIR/inference/entrypoint.sh \ + --postopp_pardir "$INPUT_DATA" \ + --inference_output_dir "$OUTPUT_RESULTS" \ + --source_plans_dir "$MODEL_FILES" + +end_time=$(date +%s) +elapsed=$((end_time - start_time)) + +echo "Inference step took ${elapsed} seconds" diff --git a/examples/cc/rano/implementation/build_modelonly.sh b/examples/cc/rano/implementation/build_modelonly.sh new file mode 100644 index 000000000..173e46b1c --- /dev/null +++ b/examples/cc/rano/implementation/build_modelonly.sh @@ -0,0 +1,2 @@ +docker build -t mlcommons/medperf-cc-rano-modelonly:0.0.0 -f Dockerfile.modelonly . 
+docker push mlcommons/medperf-cc-rano-modelonly:0.0.0 \ No newline at end of file diff --git a/examples/cc/rano/implementation/container_config_modelonly.yaml b/examples/cc/rano/implementation/container_config_modelonly.yaml new file mode 100644 index 000000000..39a7a3254 --- /dev/null +++ b/examples/cc/rano/implementation/container_config_modelonly.yaml @@ -0,0 +1,15 @@ +container_type: DockerImage +image: mlcommons/medperf-cc-rano-modelonly:0.0.0 +tasks: + infer: + input_volumes: + data_path: + mount_path: /mlcommons/volumes/data + type: directory + additional_files: + mount_path: /mlcommons/volumes/model_files + type: directory + output_volumes: + output_path: + mount_path: /mlcommons/volumes/results + type: directory From 2e3e8f1e2f0d979f74864f9c47b4fdc75fc370f8 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 12 Mar 2026 01:54:44 +0000 Subject: [PATCH 31/72] fix json yaml issue in dev data --- .../cc/rano/metrics_container_config.yaml | 18 ++ examples/cc/rano/prep_container_config.yaml | 190 ++++++++---------- examples/cc/rano/prep_parameters.yaml | 46 ++--- 3 files changed, 128 insertions(+), 126 deletions(-) create mode 100644 examples/cc/rano/metrics_container_config.yaml diff --git a/examples/cc/rano/metrics_container_config.yaml b/examples/cc/rano/metrics_container_config.yaml new file mode 100644 index 000000000..d404095bb --- /dev/null +++ b/examples/cc/rano/metrics_container_config.yaml @@ -0,0 +1,18 @@ +container_type: DockerImage +image: mlcommons/rano-metrics:0.0.0 +tasks: + evaluate: + input_volumes: + labels: + mount_path: /mlcommons/volumes/labels + type: directory + predictions: + mount_path: /mlcommons/volumes/predictions + type: directory + output_volumes: + local_outputs_path: + mount_path: /mlcommons/volumes/local_outputs + type: directory + output_path: + mount_path: /mlcommons/volumes/results/results.yaml + type: file diff --git a/examples/cc/rano/prep_container_config.yaml b/examples/cc/rano/prep_container_config.yaml index 
8841bfc9e..c9c4abcd3 100644 --- a/examples/cc/rano/prep_container_config.yaml +++ b/examples/cc/rano/prep_container_config.yaml @@ -1,101 +1,89 @@ -{ - "image": "mlcommons/rano-data-prep-mlcube:1.0.11", - "tasks": - { - "prepare": - { - "run_args": - { - "command": - [ - "prepare", - "--data_path=/mlcube_io0/", - "--labels_path=/mlcube_io1/", - "--parameters_file=/mlcube_io2/parameters.yaml", - "--models=/mlcube_io3/models", - "--output_path=/mlcube_io4/", - "--output_labels_path=/mlcube_io5/", - "--report_file=/mlcube_io6/report.yaml", - "--metadata_path=/mlcube_io7/", - ], - }, - "input_volumes": - { - "data_path": { "type": "directory", "mount_path": "/mlcube_io0" }, - "labels_path": - { "type": "directory", "mount_path": "/mlcube_io1" }, - "parameters_file": - { "type": "file", "mount_path": "/mlcube_io2/parameters.yaml" }, - "additional_files": - { "type": "directory", "mount_path": "/mlcube_io3" }, - }, - "output_volumes": - { - "output_path": - { "type": "directory", "mount_path": "/mlcube_io4" }, - "report_file": - { "type": "file", "mount_path": "/mlcube_io6/report.yaml" }, - "metadata_path": - { "type": "directory", "mount_path": "/mlcube_io7" }, - "output_labels_path": - { "type": "directory", "mount_path": "/mlcube_io5" }, - }, - }, - "statistics": - { - "run_args": - { - "command": - [ - "statistics", - "--data_path=/mlcube_io0/", - "--labels_path=/mlcube_io1/", - "--parameters_file=/mlcube_io2/parameters.yaml", - "--metadata_path=/mlcube_io3/", - "--output_path=/mlcube_io4/statistics.yaml", - ], - }, - "input_volumes": - { - "data_path": { "type": "directory", "mount_path": "/mlcube_io0" }, - "labels_path": - { "type": "directory", "mount_path": "/mlcube_io1" }, - "metadata_path": - { "type": "directory", "mount_path": "/mlcube_io3" }, - "parameters_file": - { "type": "file", "mount_path": "/mlcube_io2/parameters.yaml" }, - }, - "output_volumes": - { - "output_path": - { "type": "file", "mount_path": "/mlcube_io4/statistics.yaml" }, - }, - }, - 
"sanity_check": - { - "run_args": - { - "command": - [ - "sanity_check", - "--data_path=/mlcube_io0/", - "--labels_path=/mlcube_io1/", - "--parameters_file=/mlcube_io2/parameters.yaml", - "--metadata_path=/mlcube_io3/", - ], - }, - "input_volumes": - { - "data_path": { "type": "directory", "mount_path": "/mlcube_io0" }, - "labels_path": - { "type": "directory", "mount_path": "/mlcube_io1" }, - "metadata_path": - { "type": "directory", "mount_path": "/mlcube_io3" }, - "parameters_file": - { "type": "file", "mount_path": "/mlcube_io2/parameters.yaml" }, - }, - "output_volumes": {}, - }, - }, - "container_type": "DockerImage", -} +container_type: DockerImage +image: mlcommons/rano-data-prep-mlcube:1.0.11 +tasks: + prepare: + input_volumes: + additional_files: + mount_path: /mlcube_io3 + type: directory + data_path: + mount_path: /mlcube_io0 + type: directory + labels_path: + mount_path: /mlcube_io1 + type: directory + parameters_file: + mount_path: /mlcube_io2/parameters.yaml + type: file + output_volumes: + metadata_path: + mount_path: /mlcube_io7 + type: directory + output_labels_path: + mount_path: /mlcube_io5 + type: directory + output_path: + mount_path: /mlcube_io4 + type: directory + report_file: + mount_path: /mlcube_io6/report.yaml + type: file + run_args: + command: + - prepare + - --data_path=/mlcube_io0/ + - --labels_path=/mlcube_io1/ + - --parameters_file=/mlcube_io2/parameters.yaml + - --models=/mlcube_io3/models + - --output_path=/mlcube_io4/ + - --output_labels_path=/mlcube_io5/ + - --report_file=/mlcube_io6/report.yaml + - --metadata_path=/mlcube_io7/ + sanity_check: + input_volumes: + data_path: + mount_path: /mlcube_io0 + type: directory + labels_path: + mount_path: /mlcube_io1 + type: directory + metadata_path: + mount_path: /mlcube_io3 + type: directory + parameters_file: + mount_path: /mlcube_io2/parameters.yaml + type: file + output_volumes: {} + run_args: + command: + - sanity_check + - --data_path=/mlcube_io0/ + - --labels_path=/mlcube_io1/ + 
- --parameters_file=/mlcube_io2/parameters.yaml + - --metadata_path=/mlcube_io3/ + statistics: + input_volumes: + data_path: + mount_path: /mlcube_io0 + type: directory + labels_path: + mount_path: /mlcube_io1 + type: directory + metadata_path: + mount_path: /mlcube_io3 + type: directory + parameters_file: + mount_path: /mlcube_io2/parameters.yaml + type: file + output_volumes: + output_path: + mount_path: /mlcube_io4/statistics.yaml + type: file + run_args: + command: + - statistics + - --data_path=/mlcube_io0/ + - --labels_path=/mlcube_io1/ + - --parameters_file=/mlcube_io2/parameters.yaml + - --metadata_path=/mlcube_io3/ + - --output_path=/mlcube_io4/statistics.yaml diff --git a/examples/cc/rano/prep_parameters.yaml b/examples/cc/rano/prep_parameters.yaml index 1b71067fd..06532ab24 100644 --- a/examples/cc/rano/prep_parameters.yaml +++ b/examples/cc/rano/prep_parameters.yaml @@ -1,25 +1,21 @@ -{ - "seed": 2784, - "train_percent": 0.8, - "medperf_report_stages": - [ - "IDENTIFIED", - "VALIDATED", - "MISSING_MODALITIES", - "EXTRA_MODALITIES", - "VALIDATION_FAILED", - "CONVERTED_TO_NIfTI", - "NIfTI_CONVERSION_FAILED", - "BRAIN_EXTRACT_FINISHED", - "BRAIN_EXTRACT_FINISHED", - "TUMOR_EXTRACT_FAILED", - "MANUAL_REVIEW_COMPLETE", - "MANUAL_REVIEW_REQUIRED", - "MULTIPLE_ANNOTATIONS_ERROR", - "COMPARISON_COMPLETE", - "EXACT_MATCH_IDENTIFIED", - "ANNOTATION_COMPARISON_FAILED", - "ANNOTATION_CONFIRMED", - "DONE", - ], -} +medperf_report_stages: +- IDENTIFIED +- VALIDATED +- MISSING_MODALITIES +- EXTRA_MODALITIES +- VALIDATION_FAILED +- CONVERTED_TO_NIfTI +- NIfTI_CONVERSION_FAILED +- BRAIN_EXTRACT_FINISHED +- BRAIN_EXTRACT_FINISHED +- TUMOR_EXTRACT_FAILED +- MANUAL_REVIEW_COMPLETE +- MANUAL_REVIEW_REQUIRED +- MULTIPLE_ANNOTATIONS_ERROR +- COMPARISON_COMPLETE +- EXACT_MATCH_IDENTIFIED +- ANNOTATION_COMPARISON_FAILED +- ANNOTATION_CONFIRMED +- DONE +seed: 2784 +train_percent: 0.8 From c2523f12d36ec074bee525a8fceadfc9e55d519b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 
12 Mar 2026 05:03:13 +0000 Subject: [PATCH 32/72] don't require buckets to be public --- cli/medperf/asset_management/asset_check.py | 17 ++++---------- .../asset_management/asset_policy_manager.py | 22 ++++++++++++++++++- cli/medperf/asset_management/checks_utils.py | 10 +++++++++ cli/medperf/asset_management/gcp_utils.py | 21 ++++++++++++++++++ examples/cc/README.md | 18 +++++++-------- examples/cc/base_image/src/assets/gcp/key.py | 2 +- .../cc/base_image/src/assets/gcp/storage.py | 6 +++-- 7 files changed, 70 insertions(+), 26 deletions(-) diff --git a/cli/medperf/asset_management/asset_check.py b/cli/medperf/asset_management/asset_check.py index 5b6f83a6e..c4c662059 100644 --- a/cli/medperf/asset_management/asset_check.py +++ b/cli/medperf/asset_management/asset_check.py @@ -4,7 +4,6 @@ check_user_role_on_kms_key, check_user_role_on_wip, ) -from google.auth.credentials import AnonymousCredentials def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): @@ -13,7 +12,7 @@ def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): "user", base_creds, bucket_name, - "roles/storage.objectAdmin", + "roles/storage.admin", ) if result: return False, result @@ -33,17 +32,6 @@ def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): "roles/cloudkms.admin", ) - if result: - return False, result - - anon_creds = AnonymousCredentials() - - result = check_user_role_on_bucket( - "anonymous user", - anon_creds, - bucket_name, - "roles/storage.objectViewer", - ) if result: return False, result @@ -53,4 +41,7 @@ def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): "roles/iam.workloadIdentityPoolAdmin", ) + if result: + return False, result + return True, "" diff --git a/cli/medperf/asset_management/asset_policy_manager.py b/cli/medperf/asset_management/asset_policy_manager.py index 89f7423a1..120eece9c 100644 --- a/cli/medperf/asset_management/asset_policy_manager.py +++ 
b/cli/medperf/asset_management/asset_policy_manager.py @@ -5,6 +5,7 @@ upload_file_to_gcs, encrypt_with_kms_key, set_kms_iam_policy, + set_gcs_iam_policy, update_workload_identity_pool_oidc_provider, ) @@ -84,7 +85,7 @@ def __update_wip_oidc_provider( self.config, attribute_mapping, attribute_condition ) - def __bind_kms_decrypter_role( + def __get_principal_set( self, permitted_workloads: list[CCWorkloadID], for_model: bool = False ): principal_set = ( @@ -97,12 +98,28 @@ def __bind_kms_decrypter_role( workload_id = workload.id_for_model if for_model else workload.id principal_set_list.append(principal_set + workload_id) + return principal_set_list + + def __bind_kms_decrypter_role( + self, permitted_workloads: list[CCWorkloadID], for_model: bool = False + ): + principal_set_list = self.__get_principal_set(permitted_workloads, for_model) set_kms_iam_policy( self.config, principal_set_list, "roles/cloudkms.cryptoKeyDecrypter", ) + def __bind_gcs_object_viewer_role( + self, permitted_workloads: list[CCWorkloadID], for_model: bool = False + ): + principal_set_list = self.__get_principal_set(permitted_workloads, for_model) + set_gcs_iam_policy( + self.config, + principal_set_list, + "roles/storage.objectViewer", + ) + def setup(self): pass @@ -113,3 +130,6 @@ def setup_policy(self, policy: dict[str, str]): def configure_policy(self, permitted_workloads: list[CCWorkloadID]): self.__bind_kms_decrypter_role(permitted_workloads, for_model=self.for_model) + self.__bind_gcs_object_viewer_role( + permitted_workloads, for_model=self.for_model + ) diff --git a/cli/medperf/asset_management/checks_utils.py b/cli/medperf/asset_management/checks_utils.py index 8cfd67535..90fd5fc8c 100644 --- a/cli/medperf/asset_management/checks_utils.py +++ b/cli/medperf/asset_management/checks_utils.py @@ -46,6 +46,16 @@ def get_role_permissions(role_name: str, resource: str): "storage.objects.get", "storage.objects.list", ] + if role_name == "roles/storage.admin": + # storage permissions are 
not fully reflected in the role definition, so we hardcode them here + return [ + "storage.objects.create", + "storage.objects.delete", + "storage.objects.get", + "storage.objects.list", + "storage.buckets.setIamPolicy", + "storage.buckets.getIamPolicy", + ] service = googleapiclient.discovery.build("iam", "v1") role = service.roles().get(name=role_name).execute() permissions = role.get("includedPermissions", []) diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/gcp_utils.py index 5271e3ca4..5125b593e 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/gcp_utils.py @@ -229,6 +229,27 @@ def check_gcs_file_exists( return bucket.blob(gcs_path).exists() +def set_gcs_iam_policy(config: GCPAssetConfig, members: list[str], role: str): + client = storage.Client() + # Get current policy + + policy = client.bucket(config.bucket).get_iam_policy() + + # remove current objectviewer roles + to_remove = [] + for binding in policy.bindings: + if binding.role == role: + to_remove.append(binding) + + for binding in to_remove: + policy.bindings.remove(binding) + + policy.bindings.append(policy_pb2.Binding(role=role, members=members)) + + # Set new policy + client.bucket(config.bucket).set_iam_policy(policy) + + # run def run_workload( config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: str diff --git a/examples/cc/README.md b/examples/cc/README.md index e3ea78640..11321fe57 100644 --- a/examples/cc/README.md +++ b/examples/cc/README.md @@ -20,27 +20,27 @@ run `gcloud auth application-default login` ## asset owner - Create a bucket -- grant public access ("roles/storage.objectViewer") to the bucket +- grant the user ("roles/storage.admin") to the bucket - grant the user write access ("roles/storage.objectAdmin") to the bucket - create a keyring - - select region + - select region - create a key - - software (default) + - software (default) - grant the user "roles/cloudkms.cryptoKeyEncrypter" for 
the key - grant the user "roles/cloudkms.admin" for the key - create a workload identity pool - - add name and description + - add name and description - - add OIDC provider + - add OIDC provider "--issuer-uri=", "--allowed-audiences=", - - select name and ID to be attestation-verifier - - add the following as the google subject: - - "gcpcs::"+assertion.submods.container.image_digest+"::"+assertion.submods.gce.project_number+"::"+assertion.submods.gce.instance_id + - select name and ID to be attestation-verifier + - add the following as the google subject: + - "gcpcs::"+assertion.submods.container.image_digest+"::"+assertion.submods.gce.project_number+"::"+assertion.submods.gce.instance_id - - click create/save + - click create/save - grant user update permissions for the wip: diff --git a/examples/cc/base_image/src/assets/gcp/key.py b/examples/cc/base_image/src/assets/gcp/key.py index 21200763d..9ca52a46e 100644 --- a/examples/cc/base_image/src/assets/gcp/key.py +++ b/examples/cc/base_image/src/assets/gcp/key.py @@ -33,7 +33,7 @@ def __decrypt_bytes(self, encrypted_data: bytes) -> bytes: def initialize(self) -> None: creds = get_credentials(self.wippro) self.kms_client = kms.KeyManagementServiceClient(credentials=creds) - self.storage_client = storage.Client() + self.storage_client = storage.Client(credentials=creds) def get_key(self, output_path: str) -> None: encrypted_key = self.__get_encrypted_key() diff --git a/examples/cc/base_image/src/assets/gcp/storage.py b/examples/cc/base_image/src/assets/gcp/storage.py index 9ec00ef80..c87a6e001 100644 --- a/examples/cc/base_image/src/assets/gcp/storage.py +++ b/examples/cc/base_image/src/assets/gcp/storage.py @@ -1,17 +1,19 @@ from google.cloud import storage -from .utils import GCPAssetConfig +from .utils import GCPAssetConfig, get_credentials class GCPStorage: def __init__(self, asset_config_dict: dict): asset_config = GCPAssetConfig(**asset_config_dict) self.bucket_name = asset_config.bucket + self.wippro = 
asset_config.full_wip_name self.asset_path = asset_config.encrypted_asset_bucket_file self.storage_client = None def initialize(self) -> None: - self.storage_client = storage.Client() + creds = get_credentials(self.wippro) + self.storage_client = storage.Client(credentials=creds) def get_asset(self, output_path: str) -> None: bucket = self.storage_client.bucket(self.bucket_name) From 232e9702f40f849816a0d5c0a839d25e3a5437d1 Mon Sep 17 00:00:00 2001 From: mhmdk0 Date: Thu, 12 Mar 2026 23:17:27 +0200 Subject: [PATCH 33/72] fix cc design --- cli/medperf/web_ui/static/js/cc.js | 20 +++++---- .../templates/dataset/dataset_detail.html | 41 ++++++++++++------- .../templates/macros/cc_asset_macro.html | 6 +-- .../web_ui/templates/model/model_detail.html | 41 ++++++++++++------- 4 files changed, 69 insertions(+), 39 deletions(-) diff --git a/cli/medperf/web_ui/static/js/cc.js b/cli/medperf/web_ui/static/js/cc.js index 5309e3826..917941cea 100644 --- a/cli/medperf/web_ui/static/js/cc.js +++ b/cli/medperf/web_ui/static/js/cc.js @@ -6,7 +6,6 @@ const fields = [ "cc-key_name", "cc-key_location", "cc-wip", - "require-cc" ]; function onCCEditRequestSuccess(response){ @@ -36,11 +35,16 @@ function onCCPolicyRequestSuccess(response){ function checkForCCEditChanges() { - // const hasChanges = fields.some(field => { - // return $(`#${field}`).val() !== window.defaultCCConfig[field]; - // }); - // TODO - hasChanges = true; + const preferences = window.ccPreferences || {}; + const configured = preferences.can_apply; + const defaults = preferences.defaults || {}; + const requireCCChanged = ($("#require-cc").is(":checked") !== configured) + var hasChanges = fields.some(field => { + let currentValue = $(`#${field}`).val(); + let defaultValue = defaults[field] || ""; + return currentValue !== defaultValue; + }); + hasChanges = hasChanges || requireCCChanged; $('#apply-cc-asset-btn').prop('disabled', !hasChanges); } @@ -85,14 +89,14 @@ async function editCCConfig(editCCConfigBtn) { } 
-$(document).ready(() => { +$(document).ready(() => { const checkbox = $("#require-cc"); checkbox.on("change", () => { $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); }); $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); - fields.forEach(field => $(`#${field}`).on('input', checkForCCEditChanges)); + fields + ["require-cc"].forEach(field => $(`#${field}`).on('keyup, change', checkForCCEditChanges)); checkForCCEditChanges(); $("#apply-cc-asset-btn").on("click", (e) => { diff --git a/cli/medperf/web_ui/templates/dataset/dataset_detail.html b/cli/medperf/web_ui/templates/dataset/dataset_detail.html index d5e197c6f..7aef55cbf 100644 --- a/cli/medperf/web_ui/templates/dataset/dataset_detail.html +++ b/cli/medperf/web_ui/templates/dataset/dataset_detail.html @@ -200,6 +200,27 @@
Details
Step 3: Make a request to the benchmark owner to associate your dataset with the benchmark
+
+
+
Confidential Computing Preferences
+
+
+ {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, dataset.id, "dataset", task_running) }} +
+ {% if cc_configured %} +
+

Confidential Computing is configured, you can sync the policy.

+ +
+ {% endif %} +
{% endif %} @@ -426,20 +447,6 @@

{% endfor %} {% endif %} -{% if is_owner %} - {% if cc_configured %} - - {% endif %} - {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, dataset.id, "dataset", task_running) }} -{% endif %} - {% include "partials/panel_container.html" %} {% include "partials/text_content_container.html" %} {% include "partials/yaml_container.html" %} @@ -448,6 +455,12 @@

{% endblock detail_panel %} {% block extra_js %} + diff --git a/cli/medperf/web_ui/templates/macros/cc_asset_macro.html b/cli/medperf/web_ui/templates/macros/cc_asset_macro.html index a89ebcbc1..f623ae1e6 100644 --- a/cli/medperf/web_ui/templates/macros/cc_asset_macro.html +++ b/cli/medperf/web_ui/templates/macros/cc_asset_macro.html @@ -10,10 +10,10 @@ ("key_location", "GCP KMS Key Location"), ("wip", "GCP Workload Identity Pool Name"), ] %} -
+
-
+
@@ -34,7 +34,7 @@
+ data-entity-type="{{ entity_type }}" class="btn btn-primary mb-3" disabled>Apply Changes
diff --git a/cli/medperf/web_ui/templates/model/model_detail.html b/cli/medperf/web_ui/templates/model/model_detail.html index a8b188145..b6ef924be 100644 --- a/cli/medperf/web_ui/templates/model/model_detail.html +++ b/cli/medperf/web_ui/templates/model/model_detail.html @@ -144,6 +144,27 @@
Details
Make a request to the benchmark owner to associate your model with the benchmark
+
+
+
Confidential Computing Preferences
+
+
+ {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, entity.id, "model", task_running) }} +
+ {% if cc_configured %} +
+

Confidential Computing is configured, you can sync the policy.

+ +
+ {% endif %} +
{% endif %}
@@ -158,20 +179,6 @@

Associated Benchmarks

{% endif %} -{% if is_owner %} - {% if cc_configured %} - - {% endif %} - {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, entity.id, "model", task_running) }} -{% endif %} - {% include "partials/panel_container.html" %} {% include "partials/text_content_container.html" %} {% include "partials/yaml_container.html" %} @@ -179,6 +186,12 @@

Associated Benchmarks

{% endblock detail_panel %} {% block extra_js %} + From 0e4aaa14370cbe969d19f761366d180bb5373355 Mon Sep 17 00:00:00 2001 From: mhmdk0 Date: Thu, 12 Mar 2026 23:17:55 +0200 Subject: [PATCH 34/72] fix operator cc design --- cli/medperf/web_ui/static/js/cc_operator.js | 20 +++++++++++-------- .../templates/macros/cc_operator_macro.html | 9 +++++++-- cli/medperf/web_ui/templates/settings.html | 10 ++++++++-- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/cli/medperf/web_ui/static/js/cc_operator.js b/cli/medperf/web_ui/static/js/cc_operator.js index d189a513a..4e859a214 100644 --- a/cli/medperf/web_ui/static/js/cc_operator.js +++ b/cli/medperf/web_ui/static/js/cc_operator.js @@ -7,16 +7,20 @@ const fields = [ "operator-boot_disk_size", "operator-gpus", "operator-machine_type", - "require-cc-operator" ]; function checkForCCEditChanges() { - // const hasChanges = fields.some(field => { - // return $(`#${field}`).val() !== window.defaultCCConfig[field]; - // }); - // TODO - hasChanges = true; - $('#apply-cc-operator-btn').prop('disabled', !hasChanges); + const preferences = window.ccPreferences || {}; + const configured = preferences.can_apply; + const defaults = preferences.defaults || {}; + const requireCCChanged = ($("#require-cc-operator").is(":checked") !== configured) + var hasChanges = fields.some(field => { + let currentValue = $(`#${field}`).val(); + let defaultValue = defaults[field] || ""; + return currentValue !== defaultValue; + }); + hasChanges = hasChanges || requireCCChanged; + $('#apply-cc-operator-btn').prop('disabled', !hasChanges); } function editCCConfig(editCCConfigBtn) { @@ -54,7 +58,7 @@ $(document).ready(() => { }); $("#edit-cc-operator-fields").toggle(checkbox.is(":checked")); - fields.forEach(field => $(`#${field}`).on('input', checkForCCEditChanges)); + fields + ["require-cc-operator"].forEach(field => $(`#${field}`).on('keyup, change', checkForCCEditChanges)); checkForCCEditChanges(); $("#apply-cc-operator-btn").on("click", 
(e) => { diff --git a/cli/medperf/web_ui/templates/macros/cc_operator_macro.html b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html index a78655e8e..de07165f0 100644 --- a/cli/medperf/web_ui/templates/macros/cc_operator_macro.html +++ b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html @@ -11,10 +11,15 @@ ("boot_disk_size", "VM Boot Disk Size (GB)"), ("gpus", "VM Number of GPUs"), ] %} -
+
-
+
+

+ Confidential Computing Operator Settings +

+
+
diff --git a/cli/medperf/web_ui/templates/settings.html b/cli/medperf/web_ui/templates/settings.html index c2610d480..664def83a 100644 --- a/cli/medperf/web_ui/templates/settings.html +++ b/cli/medperf/web_ui/templates/settings.html @@ -166,14 +166,20 @@
Certificate Exists
{% include "partials/prompt_container.html" %}
- +
{% if logged_in %} {{ cc_operator_macro.cc_operator(cc_config_defaults, cc_configured, task_running=task_running) }} {% endif %} - +
{% endblock content %} {% block extra_js %} +