diff --git a/cli/medperf/asset_management/gcp_utils.py b/cli/medperf/asset_management/__gcp_util_archive.py similarity index 56% rename from cli/medperf/asset_management/gcp_utils.py rename to cli/medperf/asset_management/__gcp_util_archive.py index 37f80c4d5..50b12567f 100644 --- a/cli/medperf/asset_management/gcp_utils.py +++ b/cli/medperf/asset_management/__gcp_util_archive.py @@ -1,118 +1,14 @@ -"""Utility functions for GCP operations.""" - import logging from typing import Union from medperf.exceptions import ExecutionError from medperf.utils import run_command -from dataclasses import dataclass -from google.cloud import kms -from google.iam.v1 import policy_pb2 + from google.cloud import storage from google.cloud.exceptions import Conflict -from google.cloud import compute_v1 -import time -from colorama import Fore, Style -import medperf.config as medperf_config - -GCP_EXEC = "gcloud" - - -@dataclass -class CCWorkloadID: - data_hash: str - model_hash: str - script_hash: str - result_collector_hash: str - data_id: int - model_id: int - script_id: int - - @property - def id(self): - return "::".join( - [ - self.script_hash, - self.data_hash, - self.model_hash, - self.result_collector_hash, - ] - ) - @property - def human_readable_id(self): - return f"d{self.data_id}-m{self.model_id}-s{self.script_id}" - - @property - def vm_template_name(self): - return f"{self.human_readable_id}-vm-template" - - @property - def instance_group_name(self): - return f"{self.human_readable_id}-vm-instance-group" - - @property - def resize_request_name(self): - return f"{self.human_readable_id}-vm-instance-group-resize-request" - - @property - def vm_name(self): - return f"{self.human_readable_id}-cvm" - - @property - def results_path(self): - return f"{self.human_readable_id}/output" - - @property - def results_encryption_key_path(self): - return f"{self.human_readable_id}/encryption_key" - - -@dataclass -class GCPOperatorConfig: - project_id: str - service_account_name: str - 
account: str - bucket: str - machine_type: str - boot_disk_size: str - cc_type: str - min_cpu_platform: str - vm_zone: str - vm_network: str - logs_poll_frequency: int - gpu: bool - run_duration: str - - @property - def service_account_email(self): - return f"{self.service_account_name}@{self.project_id}.iam.gserviceaccount.com" - - -@dataclass -class GCPAssetConfig: - project_id: str - project_number: str - account: str - bucket: str - encrypted_asset_bucket_file: str - encrypted_key_bucket_file: str - keyring_name: str - key_name: str - wip: str - - @property - def full_key_name(self) -> str: - return ( - f"projects/{self.project_id}/locations/global/" - f"keyRings/{self.keyring_name}/cryptoKeys/{self.key_name}" - ) +from .gcp_utils import GCPOperatorConfig, GCPAssetConfig, CCWorkloadID - @property - def full_wip_name(self) -> str: - return ( - f"projects/{self.project_number}/locations/global/" - f"workloadIdentityPools/{self.wip}/providers/attestation-verifier" - ) +GCP_EXEC = "gcloud" # IAM Service Account operations @@ -120,7 +16,6 @@ def create_service_account(config: GCPOperatorConfig): """Create service account for workload.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "iam", "service-accounts", "create", @@ -140,7 +35,6 @@ def add_service_account_iam_policy_binding( """Add IAM policy binding to service account.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "iam", "service-accounts", "add-iam-policy-binding", @@ -155,7 +49,6 @@ def add_project_iam_policy_binding(config: GCPOperatorConfig, member: str, role: """Add IAM policy binding to project.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "projects", "add-iam-policy-binding", config.project_id, @@ -170,12 +63,11 @@ def create_keyring(config: GCPAssetConfig): """Create KMS keyring.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "kms", "keyrings", "create", config.keyring_name, - "--location=global", + f"--location={config.key_location}", ] try: run_command(cmd) 
@@ -187,12 +79,11 @@ def create_kms_key(config: GCPAssetConfig): """Create KMS key.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "kms", "keys", "create", config.key_name, - "--location=global", + f"--location={config.key_location}", f"--keyring={config.keyring_name}", "--purpose=encryption", ] @@ -207,7 +98,6 @@ def create_kms_key(config: GCPAssetConfig): def add_kms_key_iam_policy_binding(config: GCPAssetConfig, member: str, role: str): cmd = [ GCP_EXEC, - f"--project={config.project_id}", "kms", "keys", "add-iam-policy-binding", @@ -218,47 +108,11 @@ def add_kms_key_iam_policy_binding(config: GCPAssetConfig, member: str, role: st run_command(cmd) -def set_kms_iam_policy(config: GCPAssetConfig, members: list[str], role: str): - client = kms.KeyManagementServiceClient() - # Get current policy - policy = client.get_iam_policy(request={"resource": config.full_key_name}) - - # remove current decryptor roles - to_remove = [] - for binding in policy.bindings: - if binding.role == role: - to_remove.append(binding) - - for binding in to_remove: - policy.bindings.remove(binding) - - policy.bindings.append(policy_pb2.Binding(role=role, members=members)) - # Set new policy - client.set_iam_policy(request={"resource": config.full_key_name, "policy": policy}) - - -def encrypt_with_kms_key( - config: GCPAssetConfig, plaintext_file: str, ciphertext_file: str -): - """Encrypt file using KMS key.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "kms", - "encrypt", - f"--ciphertext-file={ciphertext_file}", - f"--plaintext-file={plaintext_file}", - f"--key={config.full_key_name}", - ] - run_command(cmd) - - # Workload Identity Pool operations def create_workload_identity_pool(config: GCPAssetConfig): """Create workload identity pool.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "iam", "workload-identity-pools", "create", @@ -279,12 +133,11 @@ def create_workload_identity_pool_oidc_provider( """Create OIDC provider for workload identity pool.""" 
cmd = [ GCP_EXEC, - f"--project={config.project_id}", "iam", "workload-identity-pools", "providers", "create-oidc", - "attestation-verifier", + config.wip_provider, "--location=global", f"--workload-identity-pool={config.wip}", "--issuer-uri=https://confidentialcomputing.googleapis.com/", @@ -299,25 +152,6 @@ def create_workload_identity_pool_oidc_provider( logging.debug( f"OIDC provider for workload identity pool {config.wip} may already exist. Error: {e}" ) - # try updating the provider if it already exists - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "iam", - "workload-identity-pools", - "providers", - "update-oidc", - "attestation-verifier", - "--location=global", - f"--workload-identity-pool={config.wip}", - f"--attribute-mapping={attribute_mapping}", - f"--attribute-condition={attribute_condition}", - ] - run_command(cmd) - logging.debug( - f"Updated OIDC provider for workload identity pool {config.wip}" - f" with new attribute mapping and condition." - ) # Storage operations @@ -336,43 +170,12 @@ def create_storage_bucket(config: Union[GCPAssetConfig, GCPOperatorConfig]): logging.debug(f"Bucket {config.bucket} already exists. 
Conflict: {e}") -def upload_file_to_gcs( - config: Union[GCPAssetConfig, GCPOperatorConfig], local_file: str, gcs_path: str -): - """Upload file to Google Cloud Storage.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "storage", - "cp", - local_file, - gcs_path, - ] - run_command(cmd) - - -def download_file_from_gcs( - config: Union[GCPAssetConfig, GCPOperatorConfig], gcs_path: str, local_file: str -): - """Download file from Google Cloud Storage.""" - cmd = [ - GCP_EXEC, - f"--project={config.project_id}", - "storage", - "cp", - gcs_path, - local_file, - ] - run_command(cmd) - - def add_bucket_iam_policy_binding( config: Union[GCPAssetConfig, GCPOperatorConfig], member: str, role: str ): """Add IAM policy binding to GCS bucket.""" cmd = [ GCP_EXEC, - f"--project={config.project_id}", "storage", "buckets", "add-iam-policy-binding", @@ -383,22 +186,23 @@ def add_bucket_iam_policy_binding( run_command(cmd) +# vm creation # run def run_workload( - config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: str + config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: dict[str, str] ): + metadata = "^~^" + "~".join([f"{key}={value}" for key, value in metadata.items()]) # note: machine type and cc type must conform somehow cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instances", "create", - workload_config.vm_name, + config.vm_name, f"--confidential-compute-type={config.cc_type}", "--shielded-secure-boot", "--scopes=cloud-platform", - f"--boot-disk-size={config.boot_disk_size}", + f"--boot-disk-size={config.boot_disk_size}G", f"--zone={config.vm_zone}", f"--network={config.vm_network}", "--maintenance-policy=MIGRATE", @@ -414,14 +218,13 @@ def run_workload( def run_gpu_workload( - config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: str + config: GCPOperatorConfig, workload_config: CCWorkloadID, metadata: dict[str, str] ): # note: --image-family=confidential-space-preview-cgpu # note: boot disk size 
must be at least 30GB - + metadata = "^~^" + "~".join([f"{key}={value}" for key, value in metadata.items()]) cmd = [ GCP_EXEC, - f"--project={config.project_id}", "beta", "compute", "instance-templates", @@ -436,9 +239,9 @@ def run_gpu_workload( "--image-family=confidential-space-debug-preview-cgpu", f"--service-account={config.service_account_email}", "--scopes=cloud-platform", - f"--boot-disk-size={config.boot_disk_size}", + f"--boot-disk-size={config.boot_disk_size}G", "--reservation-affinity=none", - f"--max-run-duration={config.run_duration}", + f"--max-run-duration={config.run_duration}h", "--instance-termination-action=DELETE", f"--metadata={metadata}", ] @@ -452,7 +255,6 @@ def run_gpu_workload( cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instance-groups", "managed", @@ -467,7 +269,6 @@ def run_gpu_workload( cmd = [ GCP_EXEC, - f"--project={config.project_id}", "compute", "instance-groups", "managed", @@ -481,32 +282,110 @@ def run_gpu_workload( run_command(cmd) -def wait_for_workload_completion( - config: GCPOperatorConfig, workload_config: CCWorkloadID +def update_workload_identity_pool_oidc_provider( + config: GCPAssetConfig, attribute_mapping: str, attribute_condition: str ): + cmd = [ + "gcloud", + "iam", + "workload-identity-pools", + "providers", + "update-oidc", + config.wip_provider, + "--location=global", + f"--workload-identity-pool={config.wip}", + f"--attribute-mapping={attribute_mapping}", + f"--attribute-condition={attribute_condition}", + ] + run_command(cmd) + logging.debug( + f"Updated OIDC provider for workload identity pool {config.wip}" + f" with new attribute mapping and condition." 
+ ) + - client = compute_v1.InstancesClient() - project_id = config.project_id - zone = config.vm_zone - instance_name = workload_config.vm_name - next_start = 0 - while True: - instance = client.get(project=project_id, zone=zone, instance=instance_name) - status = instance.status - if status == "TERMINATED": - return "TERMINATED" - request = compute_v1.GetSerialPortOutputInstanceRequest( - project=project_id, - zone=zone, - instance=instance_name, - start=next_start, +""" +from google.auth import impersonated_credentials + +### checks + +def impersonate_service_account(base_creds, sa_email): + logging.debug(f"Impersonating service account: {sa_email}") + try: + return impersonated_credentials.Credentials( + source_credentials=base_creds, + target_principal=sa_email, + target_scopes=["https://www.googleapis.com/auth/cloud-platform"], ) - output = client.get_serial_port_output(request=request) - if output.contents: - next_start = output.next_ - medperf_config.ui.print_subprocess_logs( - f"{Fore.WHITE}{Style.DIM}{output.contents}{Style.RESET_ALL}" + except Exception as e: + logging.debug(f"Failed to impersonate service account {sa_email}: {e}") + return None + + +# --------------------------------------------------------------------------- +# Service Account roles +# --------------------------------------------------------------------------- + + +def check_sa_roles_for_project(sa_creds, project_id, role): + + logging.debug(f"Checking service account project permissions: {project_id}") + try: + permissions = get_role_permissions( + role, "//cloudresourcemanager.googleapis.com/projects/_" + ) + crm = googleapiclient.discovery.build( + "cloudresourcemanager", "v1", credentials=sa_creds, cache_discovery=False + ) + granted = ( + crm.projects() + .testIamPermissions(resource=project_id, body={"permissions": permissions}) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + 
logging.debug(f"Missing permissions: {missing}") + return ( + f"(Role {role}) Service account missing permissions: " + f"{missing} on project: {project_id}" ) - logging.debug(output.contents) + return None + except Exception as e: + logging.debug(f"check_sa_project_permissions exception: {e}") + return f"Failed to verify service account role on project: {project_id}" + + + +def check_user_role_on_vm(creds, project_id, vm_name, vm_zone, role): - time.sleep(config.logs_poll_frequency) + logging.debug(f"Checking if user has {role} role on VM: {vm_name}") + try: + permissions = get_role_permissions( + role, + "//compute.googleapis.com/projects/_/zones/_/instances/_", + ) + compute = googleapiclient.discovery.build( + "compute", "v1", credentials=creds, cache_discovery=False + ) + granted = ( + compute.instances() + .testIamPermissions( + project=project_id, + zone=vm_zone, + resource=vm_name, + body={"permissions": permissions}, + ) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return f"(Role {role}) user missing permissions: {missing} on VM: {vm_name}" + return None + except Exception as e: + logging.debug(f"check_user_role_on_vm exception: {e}") + return f"Failed to verify user role on VM: {vm_name}" + +""" diff --git a/cli/medperf/asset_management/asset_check.py b/cli/medperf/asset_management/asset_check.py new file mode 100644 index 000000000..09638d95d --- /dev/null +++ b/cli/medperf/asset_management/asset_check.py @@ -0,0 +1,42 @@ +from medperf.asset_management.gcp_utils import checks, get_user_credentials + + +def verify_asset_owner_setup(bucket_name, kms_key_resource, wip_resource): + base_creds = get_user_credentials() + result = checks.check_user_role_on_bucket( + "user", + base_creds, + bucket_name, + "roles/storage.admin", + ) + if result: + return False, result + + result = checks.check_user_role_on_kms_key( + 
base_creds, + kms_key_resource, + "roles/cloudkms.cryptoKeyEncrypter", + ) + + if result: + return False, result + + result = checks.check_user_role_on_kms_key( + base_creds, + kms_key_resource, + "roles/cloudkms.admin", + ) + + if result: + return False, result + + result = checks.check_user_role_on_wip( + base_creds, + wip_resource, + "roles/iam.workloadIdentityPoolAdmin", + ) + + if result: + return False, result + + return True, "" diff --git a/cli/medperf/asset_management/asset_management.py b/cli/medperf/asset_management/asset_management.py index 84ec0d26d..749c1e200 100644 --- a/cli/medperf/asset_management/asset_management.py +++ b/cli/medperf/asset_management/asset_management.py @@ -1,4 +1,8 @@ -from medperf.asset_management.gcp_utils import CCWorkloadID +from medperf.asset_management.gcp_utils import ( + CCWorkloadID, + GCPAssetConfig, + GCPOperatorConfig, +) from medperf.entities.dataset import Dataset from medperf.entities.model import Model from medperf.entities.user import User @@ -7,134 +11,117 @@ from medperf.asset_management.cc_operator import OperatorManager from medperf.utils import tar, generate_tmp_path import secrets -import os -from medperf import config +from medperf.exceptions import MedperfException +from medperf import config as medperf_config -def generate_encryption_key(encryption_key_file: str): - with open(encryption_key_file, "wb") as f: - pass - os.chmod(encryption_key_file, 0o700) - with open(encryption_key_file, "ab") as f: - f.write(secrets.token_bytes(32)) +def generate_encryption_key(): + return secrets.token_bytes(32) + + +def validate_cc_config(cc_config: dict, asset_name_prefix: str): + if cc_config == {}: + return + + cc_config["encrypted_asset_bucket_file"] = asset_name_prefix + ".enc" + cc_config["encrypted_key_bucket_file"] = asset_name_prefix + "_key.enc" + + GCPAssetConfig(**cc_config) + + +def validate_cc_operator_config(cc_config: dict): + if cc_config == {}: + return + GCPOperatorConfig(**cc_config) def 
setup_dataset_for_cc(dataset: Dataset): + if not dataset.is_cc_configured(): + return cc_config = dataset.get_cc_config() cc_policy = dataset.get_cc_policy() - if not cc_config: - raise ValueError( - f"Dataset {dataset.id} does not have a configuration for confidential computing." - ) - if cc_policy is None: - raise ValueError( - f"Dataset {dataset.id} does not have a policy for confidential computing." - ) + __verify_cloud_environment(cc_config) + # create dataset asset + medperf_config.ui.text = "Compressing dataset" asset_path = generate_tmp_path() tar(asset_path, [dataset.data_path, dataset.labels_path]) - # create encryption key - encryption_key_folder = os.path.join( - config.cc_artifacts_dir, "dataset" + str(dataset.id) - ) - os.makedirs(encryption_key_folder, exist_ok=True) - encryption_key_file = os.path.join(encryption_key_folder, "encryption_key.bin") - generate_encryption_key(encryption_key_file) - - __setup_asset_for_cc(cc_config, cc_policy, asset_path, encryption_key_file) + __setup_asset_for_cc(cc_config, cc_policy, asset_path) def setup_model_for_cc(model: Model): + if not model.is_cc_configured(): + return cc_config = model.get_cc_config() cc_policy = model.get_cc_policy() - if not cc_config: - raise ValueError( - f"Model {model.id} does not have a configuration for confidential computing." - ) - if cc_policy is None: - raise ValueError( - f"Model {model.id} does not have a policy for confidential computing." - ) if model.type != "ASSET": - raise ValueError( + raise MedperfException( f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing." 
) - asset = model.asset_obj # create model asset asset_path = asset.get_archive_path() - # create encryption key - encryption_key_folder = os.path.join( - config.cc_artifacts_dir, "model" + str(model.id) - ) - os.makedirs(encryption_key_folder, exist_ok=True) - encryption_key_file = os.path.join(encryption_key_folder, "encryption_key.bin") - generate_encryption_key(encryption_key_file) + __verify_cloud_environment(cc_config) + __setup_asset_for_cc(cc_config, cc_policy, asset_path, for_model=True) - __setup_asset_for_cc(cc_config, cc_policy, asset_path, encryption_key_file) + +def __verify_cloud_environment(cc_config: dict): + AssetStorageManager(cc_config, None, None).setup() def __setup_asset_for_cc( - cc_config: dict, cc_policy: dict, asset_path: str, encryption_key_file: str + cc_config: dict, + cc_policy: dict, + asset_path: str, + for_model: bool = False, ): - # asset storage setup - asset_storage_manager = AssetStorageManager( - cc_config, asset_path, encryption_key_file - ) - asset_storage_manager.setup() + # create encryption key + encryption_key = generate_encryption_key() + + asset_storage_manager = AssetStorageManager(cc_config, asset_path, encryption_key) + asset_policy_manager = AssetPolicyManager(cc_config, for_model=for_model) + + # storage asset_storage_manager.store_asset() # policy setup - asset_policy_manager = AssetPolicyManager(cc_config, encryption_key_file) - asset_policy_manager.setup() - asset_policy_manager.setup_policy(cc_policy) + asset_policy_manager.setup_policy(cc_policy, encryption_key) + del encryption_key def update_dataset_cc_policy(dataset: Dataset, permitted_workloads: list[CCWorkloadID]): - cc_config = dataset.get_cc_config() - if not cc_config: - raise ValueError( + if not dataset.is_cc_configured(): + raise MedperfException( f"Dataset {dataset.id} does not have a configuration for confidential computing." 
) - encryption_key_folder = os.path.join( - config.cc_artifacts_dir, "dataset" + str(dataset.id) - ) - encryption_key_file = os.path.join(encryption_key_folder, "encryption_key.bin") - - asset_policy_manager = AssetPolicyManager(cc_config, encryption_key_file) + cc_config = dataset.get_cc_config() + asset_policy_manager = AssetPolicyManager(cc_config) asset_policy_manager.configure_policy(permitted_workloads) def update_model_cc_policy(model: Model, permitted_workloads: list[CCWorkloadID]): - cc_config = model.get_cc_config() - if not cc_config: - raise ValueError( + if not model.is_cc_configured(): + raise MedperfException( f"Model {model.id} does not have a configuration for confidential computing." ) + cc_config = model.get_cc_config() if model.type != "ASSET": - raise ValueError( + raise MedperfException( f"Model {model.id} is not a file-based asset and cannot be set up for confidential computing." ) - encryption_key_folder = os.path.join( - config.cc_artifacts_dir, "model" + str(model.id) - ) - encryption_key_file = os.path.join(encryption_key_folder, "encryption_key.bin") - - asset_policy_manager = AssetPolicyManager(cc_config, encryption_key_file) + asset_policy_manager = AssetPolicyManager(cc_config, for_model=True) asset_policy_manager.configure_policy(permitted_workloads) def setup_operator(user: User): - cc_config = user.get_cc_config() - if not cc_config: - raise ValueError( - "User does not have a configuration for confidential computing." 
- ) + if not user.is_cc_configured(): + return + cc_config = user.get_cc_config() operator_manager = OperatorManager(cc_config) operator_manager.setup() @@ -156,6 +143,10 @@ def run_workload( model_cc_config, result_collector_public_key, ) + + +def wait_for_workload(workload: CCWorkloadID, operator_cc_config: dict): + operator_manager = OperatorManager(operator_cc_config) operator_manager.wait_for_workload_completion(workload) @@ -168,3 +159,8 @@ def download_results( operator_manager = OperatorManager(operator_cc_config) operator_manager.download_results(workload, private_key_bytes, results_path) + + +def workload_results_exists(operator_cc_config: dict, workload: CCWorkloadID) -> bool: + operator_manager = OperatorManager(operator_cc_config) + return operator_manager.results_exist(workload) diff --git a/cli/medperf/asset_management/asset_policy_manager.py b/cli/medperf/asset_management/asset_policy_manager.py index 0d835580b..190644dc0 100644 --- a/cli/medperf/asset_management/asset_policy_manager.py +++ b/cli/medperf/asset_management/asset_policy_manager.py @@ -1,65 +1,73 @@ -from medperf.utils import generate_tmp_path -from medperf.asset_management import gcp_utils - - -class AssetPolicyManager: - def __init__(self, config: dict, encryption_key_file: str): - self.config = gcp_utils.GCPAssetConfig(**config) - self.encryption_key_file = encryption_key_file - - def __create_keyring(self): - gcp_utils.create_keyring(self.config) - - def __create_key(self): - gcp_utils.create_kms_key(self.config) +from medperf.exceptions import MedperfException +from medperf.asset_management.gcp_utils import ( + GCPAssetConfig, + CCWorkloadID, + upload_string_to_gcs, + encrypt_with_kms_key, + set_kms_iam_policy, + set_gcs_iam_policy, + update_workload_identity_pool_oidc_provider, +) +from medperf import config as medperf_config + + +def get_workload_id_scheme(for_model: bool = False): + if for_model: + return ( + 'assertion.submods.container.image_digest+"::"+' + 
"assertion.submods.container.env_override.EXPECTED_MODEL_HASH" + ) - def __add_key_iam_binding(self): - gcp_utils.add_kms_key_iam_policy_binding( - self.config, - f"user:{self.config.account}", - "roles/cloudkms.cryptoKeyEncrypter", + else: + return ( + 'assertion.submods.container.image_digest+"::"+' + 'assertion.submods.container.env_override.EXPECTED_DATA_HASH+"::"+' + 'assertion.submods.container.env_override.EXPECTED_MODEL_HASH+"::"+' + "assertion.submods.container.env_override.EXPECTED_RESULT_COLLECTOR_HASH" ) - def __create_workload_identity_pool(self): - gcp_utils.create_workload_identity_pool(self.config) - def __encrypt_key(self): - tmp_encrypted_key_path = generate_tmp_path() +class AssetPolicyManager: + def __init__(self, config: dict, for_model: bool = False): + self.config = GCPAssetConfig(**config) + self.for_model = for_model - gcp_utils.encrypt_with_kms_key( - self.config, self.encryption_key_file, tmp_encrypted_key_path - ) - return tmp_encrypted_key_path + def __encrypt_key(self, encryption_key: bytes): + encrypted_key = encrypt_with_kms_key(self.config, encryption_key) + return encrypted_key - def __upload_encrypted_key(self, tmp_encrypted_key_path): - gcp_utils.upload_file_to_gcs( + def __upload_encrypted_key(self, encrypted_key): + upload_string_to_gcs( self.config, - tmp_encrypted_key_path, - f"gs://{self.config.bucket}/{self.config.encrypted_key_bucket_file}", + encrypted_key, + self.config.encrypted_key_bucket_file, ) - def __create_wip_oidc_provider(self, policy: dict[str, str]): + def __update_wip_oidc_provider( + self, policy: dict[str, str], for_model: bool = False + ): # IMPORTANT: https://docs.cloud.google.com/confidential-computing/ # confidential-space/docs/create-grant-access-confidential-resources#attestation-assertions google_subject_attr = ( - 'google.subject="gcpcs::"' + '"gcpcs::"' '+assertion.submods.container.image_digest+"::"' '+assertion.submods.gce.project_number+"::"' "+assertion.submods.gce.instance_id" ) - 
workload_uid_attr = ( - "attribute.workload_uid=" - 'assertion.submods.container.image_digest+"::"+' - 'assertion.submods.container.env_override.EXPECTED_DATA_HASH+"::"+' - 'assertion.submods.container.env_override.EXPECTED_MODEL_HASH+"::"+' - "assertion.submods.container.env_override.EXPECTED_RESULT_COLLECTOR_HASH" - ) - - attribute_mapping = google_subject_attr + "," + workload_uid_attr + workload_uid_attr = get_workload_id_scheme(for_model=for_model) + attribute_mapping = { + "google.subject": google_subject_attr, + "attribute.workload_uid": workload_uid_attr, + } attribute_condition = 'assertion.swname == "CONFIDENTIAL_SPACE"' - attribute_condition += ( - " && 'STABLE' in assertion.submods.confidential_space.support_attributes" + + gpu_mode = 'assertion.submods.nvidia_gpu.cc_mode == "ON"' + stable_image = ( + "'STABLE' in assertion.submods.confidential_space.support_attributes" ) + # NOTE: currently it seems that gpu mode is not stable + attribute_condition += f" && ({gpu_mode} || {stable_image})" + if "location" in policy: location_condition = f'assertion.submods.gce.zone == "{policy["location"]}"' attribute_condition += f" && {location_condition}" @@ -68,18 +76,17 @@ def __create_wip_oidc_provider(self, policy: dict[str, str]): hardware_condition = f'assertion.hwmodel == "{policy["hardware"]}"' attribute_condition += f" && {hardware_condition}" - if "gpu_cc_mode" in policy: - gpu_cc_mode_condition = ( - f'assertion.submods.nvidia_gpu.cc_mode == "{policy["gpu_cc_mode"]}"' + try: + update_workload_identity_pool_oidc_provider( + self.config, attribute_mapping, attribute_condition + ) + except Exception as e: + raise MedperfException( + f"Failed to update workload identity pool OIDC provider: {e}" ) - attribute_condition += f" && {gpu_cc_mode_condition}" - - gcp_utils.create_workload_identity_pool_oidc_provider( - self.config, attribute_mapping, attribute_condition - ) - def __bind_kms_decrypter_role( - self, permitted_workloads: list[gcp_utils.CCWorkloadID] + 
def __get_principal_set( + self, permitted_workloads: list[CCWorkloadID], for_model: bool = False ): principal_set = ( f"principalSet://iam.googleapis.com/projects/{self.config.project_number}/" @@ -88,24 +95,44 @@ def __bind_kms_decrypter_role( principal_set_list = [] for workload in permitted_workloads: - principal_set_list.append(principal_set + workload.id) + workload_id = workload.id_for_model if for_model else workload.id + principal_set_list.append(principal_set + workload_id) - gcp_utils.set_kms_iam_policy( + return principal_set_list + + def __bind_kms_decrypter_role( + self, permitted_workloads: list[CCWorkloadID], for_model: bool = False + ): + principal_set_list = self.__get_principal_set(permitted_workloads, for_model) + set_kms_iam_policy( self.config, principal_set_list, "roles/cloudkms.cryptoKeyDecrypter", ) + def __bind_gcs_object_viewer_role( + self, permitted_workloads: list[CCWorkloadID], for_model: bool = False + ): + principal_set_list = self.__get_principal_set(permitted_workloads, for_model) + set_gcs_iam_policy( + self.config, + principal_set_list, + "roles/storage.objectViewer", + ) + def setup(self): - self.__create_keyring() - self.__create_key() - self.__add_key_iam_binding() - self.__create_workload_identity_pool() - - def setup_policy(self, policy: dict[str, str]): - tmp_encrypted_key_path = self.__encrypt_key() - self.__upload_encrypted_key(tmp_encrypted_key_path) - self.__create_wip_oidc_provider(policy) - - def configure_policy(self, permitted_workloads: list[gcp_utils.CCWorkloadID]): - self.__bind_kms_decrypter_role(permitted_workloads) + pass + + def setup_policy(self, policy: dict[str, str], encryption_key: bytes): + medperf_config.ui.text = "Encrypting Key using GCP KMS" + encrypted_key = self.__encrypt_key(encryption_key) + medperf_config.ui.text = "Uploading Encrypted Key to GCP bucket" + self.__upload_encrypted_key(encrypted_key) + medperf_config.ui.text = "Setting up Workload Identity Pool" + 
self.__update_wip_oidc_provider(policy, for_model=self.for_model) + + def configure_policy(self, permitted_workloads: list[CCWorkloadID]): + self.__bind_kms_decrypter_role(permitted_workloads, for_model=self.for_model) + self.__bind_gcs_object_viewer_role( + permitted_workloads, for_model=self.for_model + ) diff --git a/cli/medperf/asset_management/asset_storage_manager.py b/cli/medperf/asset_management/asset_storage_manager.py index bc8ff1d9e..f772a3217 100644 --- a/cli/medperf/asset_management/asset_storage_manager.py +++ b/cli/medperf/asset_management/asset_storage_manager.py @@ -1,43 +1,53 @@ -from medperf.utils import generate_tmp_path, get_file_hash +from medperf.utils import ( + generate_tmp_path, + tmp_path_for_cc_asset_key, + secure_write_to_file, + get_file_hash, + remove_path, +) from medperf.encryption import SymmetricEncryption -from medperf.asset_management import gcp_utils +from medperf.asset_management.gcp_utils import GCPAssetConfig, upload_file_to_gcs +from medperf.asset_management.asset_check import verify_asset_owner_setup +from medperf.exceptions import MedperfException +from medperf import config as medperf_config class AssetStorageManager: - def __init__(self, config: dict, asset_path: str, encryption_key_file: str): - self.config = gcp_utils.GCPAssetConfig(**config) + def __init__(self, config: dict, asset_path: str, encryption_key: bytes): + self.config = GCPAssetConfig(**config) self.asset_path = asset_path - self.encryption_key_file = encryption_key_file - - def __create_bucket(self): - gcp_utils.create_storage_bucket(self.config) + self.encryption_key = encryption_key def __encrypt_asset(self): tmp_encrypted_asset_path = generate_tmp_path() + encryption_key_file = tmp_path_for_cc_asset_key() + secure_write_to_file(encryption_key_file, self.encryption_key) SymmetricEncryption().encrypt_file( - self.asset_path, self.encryption_key_file, tmp_encrypted_asset_path + self.asset_path, encryption_key_file, tmp_encrypted_asset_path ) + 
remove_path(encryption_key_file, sensitive=True) asset_hash = get_file_hash(tmp_encrypted_asset_path) return tmp_encrypted_asset_path, asset_hash def __upload_encrypted_asset(self, tmp_encrypted_asset_path): - gcp_utils.upload_file_to_gcs( + upload_file_to_gcs( self.config, tmp_encrypted_asset_path, - f"gs://{self.config.bucket}/{self.config.encrypted_asset_bucket_file}", - ) - - def __grant_bucket_public_read_access(self): - gcp_utils.add_bucket_iam_policy_binding( - self.config, "allUsers", "roles/storage.objectViewer" + self.config.encrypted_asset_bucket_file, ) def setup(self): - self.__create_bucket() - self.__grant_bucket_public_read_access() + medperf_config.ui.text = "Verifying Cloud Environment" + success, message = verify_asset_owner_setup( + self.config.bucket, self.config.full_key_name, self.config.full_wip_name + ) + if not success: + raise MedperfException(f"Asset owner setup verification failed: {message}") def store_asset(self): + medperf_config.ui.text = "Encrypting data locally" tmp_encrypted_asset_path, asset_hash = self.__encrypt_asset() + medperf_config.ui.text = "Uploading Encrypted data to GCP bucket" self.__upload_encrypted_asset(tmp_encrypted_asset_path) return asset_hash diff --git a/cli/medperf/asset_management/cc_operator.py b/cli/medperf/asset_management/cc_operator.py index 6fc2cb793..23174f679 100644 --- a/cli/medperf/asset_management/cc_operator.py +++ b/cli/medperf/asset_management/cc_operator.py @@ -1,95 +1,44 @@ import json -from medperf.asset_management import gcp_utils -from medperf.utils import generate_tmp_path, untar +from medperf.asset_management.gcp_utils import ( + GCPOperatorConfig, + CCWorkloadID, + download_file_from_gcs, + download_string_from_gcs, + check_gcs_file_exists, + run_workload, + wait_for_workload_completion, +) +from medperf.asset_management.operator_check import verify_operator_setup +from medperf.exceptions import MedperfException, ExecutionError +from medperf.utils import ( + generate_tmp_path, + untar, 
+ tmp_path_for_cc_asset_key, + secure_write_to_file, + remove_path, +) from medperf.encryption import SymmetricEncryption, AsymmetricEncryption - - -def create_service_account(config: gcp_utils.GCPOperatorConfig): - gcp_utils.create_service_account(config) - - -def allow_operator_to_use_service_account(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_service_account_iam_policy_binding( - config, - f"user:{config.account}", - "roles/iam.serviceAccountUser", - ) - - -def grant_confidential_computing_workload_user(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_project_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/confidentialcomputing.workloadUser", - ) - - -def grant_logging_log_writer(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_project_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/logging.logWriter", - ) - - -def run_workload( - config: gcp_utils.GCPOperatorConfig, - docker_image: str, - env_vars: dict, - workload: gcp_utils.CCWorkloadID, -): - # Build metadata string - metadata_parts = [ - f"tee-image-reference={docker_image}", - "tee-container-log-redirect=true", - ] - - # Add environment variables - for key, value in env_vars.items(): - metadata_parts.append(f"tee-env-{key}={value}") - - if config.gpu: - metadata_parts.append("tee-install-gpu-driver=true") - - metadata = "^~^" + "~".join(metadata_parts) - - if config.gpu: - gcp_utils.run_gpu_workload(config, workload, metadata) - else: - gcp_utils.run_workload(config, workload, metadata) - - -def grant_bucket_read_access(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_bucket_iam_policy_binding( - config, f"user:{config.account}", "roles/storage.objectViewer" - ) - - -def grant_bucket_write_access(config: gcp_utils.GCPOperatorConfig): - gcp_utils.add_bucket_iam_policy_binding( - config, - f"serviceAccount:{config.service_account_email}", - "roles/storage.objectAdmin", - ) +from colorama import 
Fore, Style +import medperf.config as medperf_config class OperatorManager: def __init__(self, config: dict): - self.config = gcp_utils.GCPOperatorConfig(**config) + self.config = GCPOperatorConfig(**config) def setup(self): """Set up complete operator infrastructure""" - create_service_account(self.config) - allow_operator_to_use_service_account(self.config) - grant_confidential_computing_workload_user(self.config) - grant_logging_log_writer(self.config) - grant_bucket_read_access(self.config) - grant_bucket_write_access(self.config) + success, message = verify_operator_setup( + self.config.service_account_email, self.config.bucket + ) + + if not success: + raise MedperfException(f"Operator setup verification failed: {message}") def run_workload( self, docker_image: str, - workload: gcp_utils.CCWorkloadID, + workload: CCWorkloadID, dataset_cc_config: dict, model_cc_config: dict, result_collector_public_key: str, @@ -115,47 +64,67 @@ def run_workload( "RESULT_COLLECTOR": result_collector_public_key, "EXPECTED_RESULT_COLLECTOR_HASH": workload.result_collector_hash, } - run_workload(self.config, docker_image, env_vars, workload) - - def wait_for_workload_completion(self, workload: gcp_utils.CCWorkloadID): - gcp_utils.wait_for_workload_completion(self.config, workload) + metadata = {} + metadata["tee-image-reference"] = docker_image + metadata["tee-container-log-redirect"] = "true" + + # Add environment variables + for key, value in env_vars.items(): + metadata[f"tee-env-{key}"] = value + + try: + run_workload(self.config, metadata) + except Exception: + raise ExecutionError( + "Failed to run workload: User lacks permissions or VM does not exist" + ) + + def wait_for_workload_completion(self, workload: CCWorkloadID): + for output in wait_for_workload_completion(self.config, workload): + medperf_config.ui.print_subprocess_logs( + f"{Fore.WHITE}{Style.DIM}{output}{Style.RESET_ALL}" + ) + + def results_exist(self, workload: CCWorkloadID): + results_exist = 
check_gcs_file_exists(self.config, workload.results_path) + if not results_exist: + return False + decryption_key_exists = check_gcs_file_exists( + self.config, workload.results_encryption_key_path + ) + return decryption_key_exists def download_results( self, - workload: gcp_utils.CCWorkloadID, + workload: CCWorkloadID, private_key_bytes: bytes, results_path: str, ): encrypted_results_path = generate_tmp_path() - key_path = generate_tmp_path() - gcp_utils.download_file_from_gcs( - self.config, - f"gs://{self.config.bucket}/{workload.results_path}", - encrypted_results_path, + download_file_from_gcs( + self.config, workload.results_path, encrypted_results_path ) - gcp_utils.download_file_from_gcs( - self.config, - f"gs://{self.config.bucket}/{workload.results_encryption_key_path}", - key_path, + encrypted_key = download_string_from_gcs( + self.config, workload.results_encryption_key_path ) - with open(key_path, "rb") as key_file: - encrypted_key = key_file.read() + medperf_config.ui.text = "Decrypting predictions" decryption_key = AsymmetricEncryption().decrypt( private_key_bytes, encrypted_key ) - tmp_key_path = generate_tmp_path() - with open(tmp_key_path, "wb") as tmp_key_file: - tmp_key_file.write(decryption_key) - results_archive_path = generate_tmp_path() + tmp_key_path = tmp_path_for_cc_asset_key() + secure_write_to_file(tmp_key_path, decryption_key) SymmetricEncryption().decrypt_file( encrypted_results_path, tmp_key_path, results_archive_path ) + remove_path(tmp_key_path, sensitive=True) + del decryption_key # Extract results + medperf_config.ui.text = "Uncompressing predictions" untar(results_archive_path, remove=True, extract_to=results_path) diff --git a/cli/medperf/asset_management/gcp_utils/__init__.py b/cli/medperf/asset_management/gcp_utils/__init__.py new file mode 100644 index 000000000..fcb456b39 --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/__init__.py @@ -0,0 +1,33 @@ +from .compute import run_workload, wait_for_workload_completion 
+from .kms import set_kms_iam_policy, encrypt_with_kms_key +from .storage import ( + upload_file_to_gcs, + upload_string_to_gcs, + download_file_from_gcs, + download_string_from_gcs, + check_gcs_file_exists, + set_gcs_iam_policy, +) +from .types import CCWorkloadID, GCPOperatorConfig, GCPAssetConfig +from .workload_identity import update_workload_identity_pool_oidc_provider +from . import checks +from .utils import get_user_credentials + +__all__ = [ + "run_workload", + "wait_for_workload_completion", + "set_kms_iam_policy", + "encrypt_with_kms_key", + "upload_file_to_gcs", + "upload_string_to_gcs", + "download_file_from_gcs", + "download_string_from_gcs", + "check_gcs_file_exists", + "set_gcs_iam_policy", + "update_workload_identity_pool_oidc_provider", + "CCWorkloadID", + "GCPOperatorConfig", + "GCPAssetConfig", + "checks", + "get_user_credentials", +] diff --git a/cli/medperf/asset_management/gcp_utils/checks.py b/cli/medperf/asset_management/gcp_utils/checks.py new file mode 100644 index 000000000..c45b21e4c --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/checks.py @@ -0,0 +1,212 @@ +import logging +import googleapiclient.discovery + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +def get_testable_permissions(resource): + if resource.startswith("//cloudresourcemanager"): + # don't find testable permissions for cloudresourcemanager + # Since they are a lot. 
+ return { + "confidentialcomputing.challenges.create", + "confidentialcomputing.challenges.verify", + "confidentialcomputing.locations.get", + "confidentialcomputing.locations.list", + "logging.logEntries.create", + "logging.logEntries.route", + } + if "workloadIdentityPools" in resource: + # doesn't seem to have an api for this + return { + "iam.googleapis.com/workloadIdentityPoolProviders.update", + "iam.googleapis.com/workloadIdentityPoolProviders.get", + "iam.googleapis.com/workloadIdentityPools.get", + } + + iam = googleapiclient.discovery.build("iam", "v1") + resp = ( + iam.permissions() + .queryTestablePermissions(body={"fullResourceName": resource, "pageSize": 1000}) + .execute() + ) + return set(p["name"] for p in resp.get("permissions", [])) + + +def get_role_permissions(role_name: str, resource: str): + if role_name == "roles/storage.objectAdmin": + # storage permissions are not fully reflected in the role definition, so we hardcode them here + return [ + "storage.objects.create", + "storage.objects.delete", + "storage.objects.get", + "storage.objects.list", + ] + if role_name == "roles/storage.admin": + # storage permissions are not fully reflected in the role definition, so we hardcode them here + return [ + "storage.objects.create", + "storage.objects.delete", + "storage.objects.get", + "storage.objects.list", + "storage.buckets.setIamPolicy", + "storage.buckets.getIamPolicy", + ] + if role_name == "roles/compute.instanceAdmin.v1": + return [ + "compute.instances.setMetadata", + "compute.instances.start", + "compute.instances.get", + "compute.instances.getSerialPortOutput", + ] + service = googleapiclient.discovery.build("iam", "v1") + role = service.roles().get(name=role_name).execute() + permissions = role.get("includedPermissions", []) + testable = get_testable_permissions(resource) + return list(testable.intersection(permissions)) + + +# --------------------------------------------------------------------------- +# User roles +# 
--------------------------------------------------------------------------- +def check_user_role_on_service_account(base_creds, sa_email, role): + logging.debug(f"Checking if user has {role} role on {sa_email}") + try: + permissions = get_role_permissions( + role, "//iam.googleapis.com/projects/_/serviceAccounts/_" + ) + iam = googleapiclient.discovery.build( + "iam", "v1", credentials=base_creds, cache_discovery=False + ) + granted = ( + iam.projects() + .serviceAccounts() + .testIamPermissions( + resource=f"projects/-/serviceAccounts/{sa_email}", + body={"permissions": permissions}, + ) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return f"(Role {role}) User missing permissions: {missing} on service account: {sa_email}" + return None + except Exception as e: + logging.debug(f"check_user_role_on_service_account exception: {e}") + return f"Failed to verify user role on service account: {sa_email}" + + +def check_user_role_on_bucket(user_str, creds, bucket_name, role): + + logging.debug(f"Checking if {user_str} has {role} role on bucket: {bucket_name}") + try: + permissions = get_role_permissions( + role, "//storage.googleapis.com/projects/_/buckets/_" + ) + iam = googleapiclient.discovery.build( + "storage", "v1", credentials=creds, cache_discovery=False + ) + granted = ( + iam.buckets() + .testIamPermissions(bucket=bucket_name, permissions=permissions) + .execute() + ) + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + if missing: + logging.debug(f"Missing permissions: {missing}") + return f"(Role {role}) {user_str} missing permissions: {missing} on bucket: {bucket_name}" + return None + except Exception as e: + logging.debug(f"check_user_role_on_bucket exception: {e}") + return f"Failed to verify {user_str} role on bucket: {bucket_name}" + + +# 
--------------------------------------------------------------------------- +# KMS roles +# --------------------------------------------------------------------------- + + +def check_user_role_on_kms_key(base_creds, kms_key_resource, role): + logging.debug(f"Checking user role {role} on KMS key {kms_key_resource}") + + try: + kms = googleapiclient.discovery.build( + "cloudkms", "v1", credentials=base_creds, cache_discovery=False + ) + + permissions = get_role_permissions( + role, + "//cloudkms.googleapis.com/projects/_/locations/_/keyRings/_/cryptoKeys/_", + ) + + granted = ( + kms.projects() + .locations() + .keyRings() + .cryptoKeys() + .testIamPermissions( + resource=kms_key_resource, + body={"permissions": permissions}, + ) + .execute() + ) + + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + + if missing: + return f"(Role {role}) User missing permissions: {missing} on KMS key {kms_key_resource}" + + return None + + except Exception as e: + logging.debug(f"KMS permission check failed: {e}") + return f"Failed verifying user roles on KMS key {kms_key_resource}" + + +# --------------------------------------------------------------------------- +# WIP roles +# --------------------------------------------------------------------------- + + +def check_user_role_on_wip(creds, wip, role): + + logging.debug(f"Checking user role {role} on WIP {wip}") + + try: + iam = googleapiclient.discovery.build( + "iam", "v1", credentials=creds, cache_discovery=False + ) + + permissions = get_role_permissions( + role, "//iam.googleapis.com/projects/_/locations/_/workloadIdentityPools/_" + ) + + granted = ( + iam.projects() + .locations() + .workloadIdentityPools() + .testIamPermissions( + resource=wip, + body={"permissions": permissions}, + ) + .execute() + ) + + granted_permissions = granted.get("permissions", []) + missing = set(permissions) - set(granted_permissions) + + if missing: + return f"(Role {role}) User missing 
permissions: {missing} on WIP {wip}" + + return None + + except Exception as e: + logging.debug(f"WIP permission check failed: {e}") + return f"Failed verifying user roles on WIP {wip}" diff --git a/cli/medperf/asset_management/gcp_utils/compute.py b/cli/medperf/asset_management/gcp_utils/compute.py new file mode 100644 index 000000000..66a5c8162 --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/compute.py @@ -0,0 +1,119 @@ +import logging +from google.cloud import compute_v1 +import time +from .types import GCPOperatorConfig, CCWorkloadID + + +# taken and adapted from +# https://docs.cloud.google.com/compute/docs/samples/compute-start-instance#compute_start_instance-python +def __wait_for_extended_operation(operation, verbose_name, timeout: int = 300): + """ + Waits for the extended (long-running) operation to complete. + + If the operation is successful, it will return its result. + If the operation ends with an error, an exception will be raised. + If there were any warnings during the execution of the operation + they will be printed to sys.stderr. + + Args: + operation: a long-running operation you want to wait on. + verbose_name: (optional) a more verbose name of the operation, + used only during error and warning reporting. + timeout: how long (in seconds) to wait for operation to finish. + If None, wait indefinitely. + + Returns: + Whatever the operation.result() returns. + + Raises: + This method will raise the exception received from `operation.exception()` + or RuntimeError if there is no exception set, but there is an `error_code` + set for the `operation`. + + In case of an operation taking longer than `timeout` seconds to complete, + a `concurrent.futures.TimeoutError` will be raised. 
+ """ + result = operation.result(timeout=timeout) + + if operation.error_code: + logging.debug( + f"Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}" + ) + raise operation.exception() or RuntimeError(operation.error_message) + + if operation.warnings: + logging.debug(f"Warnings during {verbose_name}:\n") + for warning in operation.warnings: + logging.debug( + f"WARNING from {verbose_name} - {warning.code}: {warning.message}" + ) + + return result + + +def run_workload(config: GCPOperatorConfig, metadata: dict[str, str]): + """Run workload on GCP.""" + client = compute_v1.InstancesClient() + project_id = config.project_id + zone = config.vm_zone + instance_name = config.vm_name + + instance = client.get(project=project_id, zone=zone, instance=instance_name) + has_gpu = len(instance.guest_accelerators) > 0 + if has_gpu: + metadata["tee-install-gpu-driver"] = "true" + metadata_items = [] + for key, value in metadata.items(): + metadata_items.append(compute_v1.Items(key=key, value=value)) + + metadata_resource = compute_v1.Metadata( + fingerprint=instance.metadata.fingerprint, + items=metadata_items, + ) + try: + operation = client.set_metadata( + project=project_id, + zone=zone, + instance=instance_name, + metadata_resource=metadata_resource, + ) + __wait_for_extended_operation(operation, "set vm metadata") + except Exception as e: + logging.error(f"Failed to set metadata for instance {instance_name}: {e}") + raise + + try: + operation = client.start(project=project_id, zone=zone, instance=instance_name) + __wait_for_extended_operation(operation, "start vm") + except Exception as e: + logging.error(f"Failed to start instance {instance_name}: {e}") + raise + + +def wait_for_workload_completion( + config: GCPOperatorConfig, workload_config: CCWorkloadID +): + + client = compute_v1.InstancesClient() + project_id = config.project_id + zone = config.vm_zone + instance_name = config.vm_name + next_start = 0 + while True: + instance = 
client.get(project=project_id, zone=zone, instance=instance_name) + status = instance.status + if status == "TERMINATED": + return + request = compute_v1.GetSerialPortOutputInstanceRequest( + project=project_id, + zone=zone, + instance=instance_name, + start=next_start, + ) + output = client.get_serial_port_output(request=request) + if output.contents: + next_start = output.next_ + logging.debug(output.contents) + yield output.contents + + time.sleep(config.logs_poll_frequency) diff --git a/cli/medperf/asset_management/gcp_utils/kms.py b/cli/medperf/asset_management/gcp_utils/kms.py new file mode 100644 index 000000000..72b5b1832 --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/kms.py @@ -0,0 +1,36 @@ +from google.cloud import kms_v1 as kms +from google.iam.v1 import policy_pb2 +from .types import GCPAssetConfig +import logging + + +def set_kms_iam_policy(config: GCPAssetConfig, members: list[str], role: str): + client = kms.KeyManagementServiceClient() + # Get current policy + policy = client.get_iam_policy(request={"resource": config.full_key_name}) + + # remove current decryptor roles + to_remove = [] + for binding in policy.bindings: + if binding.role == role: + to_remove.append(binding) + + for binding in to_remove: + policy.bindings.remove(binding) + + policy.bindings.append(policy_pb2.Binding(role=role, members=members)) + # Set new policy + client.set_iam_policy(request={"resource": config.full_key_name, "policy": policy}) + + +def encrypt_with_kms_key(config: GCPAssetConfig, plaintext: bytes) -> bytes: + """Encrypt a string using a KMS key via Python client.""" + client = kms.KeyManagementServiceClient() + + # Encrypt + response = client.encrypt( + request={"name": config.full_key_name, "plaintext": plaintext} + ) + + logging.debug(f"Encrypted using {config.full_key_name}") + return response.ciphertext diff --git a/cli/medperf/asset_management/gcp_utils/storage.py b/cli/medperf/asset_management/gcp_utils/storage.py new file mode 100644 index 
000000000..e971c3b46 --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/storage.py @@ -0,0 +1,72 @@ +from typing import Union +from google.cloud import storage +from .types import GCPAssetConfig, GCPOperatorConfig + + +def upload_file_to_gcs( + config: Union[GCPAssetConfig, GCPOperatorConfig], local_file: str, gcs_path: str +): + """Upload file to Google Cloud Storage.""" + client = storage.Client() + bucket = client.bucket(config.bucket) + blob = bucket.blob(gcs_path) + blob.upload_from_filename(local_file) + + +def upload_string_to_gcs( + config: Union[GCPAssetConfig, GCPOperatorConfig], content: bytes, gcs_path: str +): + """Upload string content to Google Cloud Storage.""" + client = storage.Client() + bucket = client.bucket(config.bucket) + blob = bucket.blob(gcs_path) + blob.upload_from_string(content) + + +def download_file_from_gcs( + config: Union[GCPAssetConfig, GCPOperatorConfig], gcs_path: str, local_file: str +): + """Download file from Google Cloud Storage.""" + client = storage.Client() + bucket = client.bucket(config.bucket) + blob = bucket.blob(gcs_path) + blob.download_to_filename(local_file) + + +def download_string_from_gcs( + config: Union[GCPAssetConfig, GCPOperatorConfig], gcs_path: str +) -> bytes: + """Download string content from Google Cloud Storage.""" + client = storage.Client() + bucket = client.bucket(config.bucket) + blob = bucket.blob(gcs_path) + return blob.download_as_bytes() + + +def check_gcs_file_exists( + config: Union[GCPAssetConfig, GCPOperatorConfig], gcs_path: str +) -> bool: + client = storage.Client() + bucket = client.bucket(config.bucket) + return bucket.blob(gcs_path).exists() + + +def set_gcs_iam_policy(config: GCPAssetConfig, members: list[str], role: str): + client = storage.Client() + # Get current policy + + policy = client.bucket(config.bucket).get_iam_policy() + + # remove current objectviewer roles + to_remove = [] + for binding in policy.bindings: + if binding["role"] == role: + 
to_remove.append(binding) + + for binding in to_remove: + policy.bindings.remove(binding) + + policy.bindings.append({"role": role, "members": members}) + + # Set new policy + client.bucket(config.bucket).set_iam_policy(policy) diff --git a/cli/medperf/asset_management/gcp_utils/types.py b/cli/medperf/asset_management/gcp_utils/types.py new file mode 100644 index 000000000..db2454a7a --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/types.py @@ -0,0 +1,93 @@ +from pydantic import BaseModel + + +class CCWorkloadID(BaseModel): + data_hash: str + model_hash: str + script_hash: str + result_collector_hash: str + data_id: int + model_id: int + script_id: int + execution_id: int = None + + @property + def id(self): + return "::".join( + [ + self.script_hash, + self.data_hash, + self.model_hash, + self.result_collector_hash, + ] + ) + + @property + def id_for_model(self): + return "::".join( + [ + self.script_hash, + self.model_hash, + ] + ) + + @property + def human_readable_id(self): + if self.execution_id: + return f"d{self.data_id}-m{self.model_id}-s{self.script_id}-e{self.execution_id}" + return f"d{self.data_id}-m{self.model_id}-s{self.script_id}" + + @property + def results_path(self): + return f"{self.human_readable_id}/output" + + @property + def results_encryption_key_path(self): + return f"{self.human_readable_id}/encryption_key" + + +class GCPOperatorConfig(BaseModel): + project_id: str + service_account_name: str + bucket: str + vm_name: str + vm_zone: str + logs_poll_frequency: int = 30 # seconds + + @property + def service_account_email(self): + return f"{self.service_account_name}@{self.project_id}.iam.gserviceaccount.com" + + +class GCPAssetConfig(BaseModel): + project_id: str + project_number: str + bucket: str + encrypted_asset_bucket_file: str + encrypted_key_bucket_file: str + keyring_name: str + key_name: str + key_location: str + wip: str + wip_provider: str + + @property + def full_key_name(self) -> str: + return ( + 
f"projects/{self.project_id}/locations/{self.key_location}/" + f"keyRings/{self.keyring_name}/cryptoKeys/{self.key_name}" + ) + + @property + def full_wip_provider_name(self) -> str: + return ( + f"projects/{self.project_number}/locations/global/" + f"workloadIdentityPools/{self.wip}/providers/{self.wip_provider}" + ) + + @property + def full_wip_name(self) -> str: + return ( + f"projects/{self.project_number}/locations/global/" + f"workloadIdentityPools/{self.wip}" + ) diff --git a/cli/medperf/asset_management/gcp_utils/utils.py b/cli/medperf/asset_management/gcp_utils/utils.py new file mode 100644 index 000000000..7e69dfcca --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/utils.py @@ -0,0 +1,6 @@ +import google.auth + + +def get_user_credentials(): + creds, _ = google.auth.default() + return creds diff --git a/cli/medperf/asset_management/gcp_utils/workload_identity.py b/cli/medperf/asset_management/gcp_utils/workload_identity.py new file mode 100644 index 000000000..9cc9e1258 --- /dev/null +++ b/cli/medperf/asset_management/gcp_utils/workload_identity.py @@ -0,0 +1,40 @@ +import logging +from .types import GCPAssetConfig +from googleapiclient.discovery import build + + +def update_workload_identity_pool_oidc_provider( + config: GCPAssetConfig, attribute_mapping: dict, attribute_condition: str +): + # Authenticate + iam = build("iam", "v1") + + # Construct the full provider name + provider_name = config.full_wip_provider_name + + body = { + "attributeMapping": attribute_mapping, + "attributeCondition": attribute_condition, + } + + # Update the OIDC provider + try: + request = ( + iam.projects() + .locations() + .workloadIdentityPools() + .providers() + .patch( + name=provider_name, + updateMask="attributeMapping,attributeCondition", + body=body, + ) + ) + request.execute() + except Exception as e: + logging.debug(f"Failed to update OIDC provider {provider_name}: {e}") + raise + logging.debug( + f"Updated OIDC provider for workload identity pool 
{config.wip} " + f"with new attribute mapping and condition." + ) diff --git a/cli/medperf/asset_management/operator_check.py b/cli/medperf/asset_management/operator_check.py new file mode 100644 index 000000000..541077845 --- /dev/null +++ b/cli/medperf/asset_management/operator_check.py @@ -0,0 +1,23 @@ +from medperf.asset_management.gcp_utils import checks, get_user_credentials + + +def verify_operator_setup(sa_email, bucket_name): + base_creds = get_user_credentials() + result = checks.check_user_role_on_service_account( + base_creds, + sa_email, + "roles/iam.serviceAccountUser", + ) + if result: + return False, result + + result = checks.check_user_role_on_bucket( + "user", + base_creds, + bucket_name, + "roles/storage.objectViewer", + ) + if result: + return False, result + + return True, "" diff --git a/cli/medperf/commands/cc/cc.py b/cli/medperf/commands/cc/cc.py index 4e218f631..e6c4f437b 100644 --- a/cli/medperf/commands/cc/cc.py +++ b/cli/medperf/commands/cc/cc.py @@ -24,7 +24,7 @@ def configure_dataset_for_cc( ): """Configure dataset for confidential computing execution""" ui = config.ui - DatasetConfigureForCC.run(data_uid, cc_config_file, cc_policy_file) + DatasetConfigureForCC.run_from_files(data_uid, cc_config_file, cc_policy_file) ui.print("✅ Done!") @@ -41,7 +41,7 @@ def configure_model_for_cc( ): """Configure model for confidential computing execution""" ui = config.ui - ModelConfigureForCC.run(model_uid, cc_config_file, cc_policy_file) + ModelConfigureForCC.run_from_files(model_uid, cc_config_file, cc_policy_file) ui.print("✅ Done!") @@ -76,5 +76,5 @@ def setup_cc_operator( ): """Setup confidential computing operator""" ui = config.ui - SetupCCOperator.run(cc_config_file) + SetupCCOperator.run_from_files(cc_config_file) ui.print("✅ Done!") diff --git a/cli/medperf/commands/cc/dataset_configure_for_cc.py b/cli/medperf/commands/cc/dataset_configure_for_cc.py index e2262902f..49dd023b2 100644 --- 
a/cli/medperf/commands/cc/dataset_configure_for_cc.py +++ b/cli/medperf/commands/cc/dataset_configure_for_cc.py @@ -1,24 +1,28 @@ from medperf.entities.dataset import Dataset -from medperf.asset_management.asset_management import setup_dataset_for_cc +from medperf.asset_management.asset_management import ( + setup_dataset_for_cc, + validate_cc_config, +) import json from medperf import config class DatasetConfigureForCC: @classmethod - def run(cls, data_uid: int, cc_config_file: str, cc_policy_file: str): - dataset = Dataset.get(data_uid) + def run_from_files(cls, data_uid: int, cc_config_file: str, cc_policy_file: str): with open(cc_config_file) as f: cc_config = json.load(f) with open(cc_policy_file) as f: cc_policy = json.load(f) - dataset.set_cc_config(cc_config) - dataset.set_cc_policy(cc_policy) - body = {"user_metadata": dataset.user_metadata} - config.comms.update_dataset(dataset.id, body) - setup_dataset_for_cc(dataset) + cls.run(data_uid, cc_config, cc_policy) - # mark as set - dataset.mark_cc_configured() - body = {"user_metadata": dataset.user_metadata} - config.comms.update_dataset(dataset.id, body) + @classmethod + def run(cls, data_uid: int, cc_config: dict, cc_policy: dict): + validate_cc_config(cc_config, "dataset" + str(data_uid)) + with config.ui.interactive(): + dataset = Dataset.get(data_uid) + dataset.set_cc_config(cc_config) + dataset.set_cc_policy(cc_policy) + setup_dataset_for_cc(dataset) + body = {"user_metadata": dataset.user_metadata} + config.comms.update_dataset(dataset.id, body) diff --git a/cli/medperf/commands/cc/dataset_update_cc_policy.py b/cli/medperf/commands/cc/dataset_update_cc_policy.py index 6f0992c8d..69e408b0d 100644 --- a/cli/medperf/commands/cc/dataset_update_cc_policy.py +++ b/cli/medperf/commands/cc/dataset_update_cc_policy.py @@ -9,6 +9,7 @@ from medperf.entities.cube import Cube from medperf.entities.certificate import Certificate from medperf.utils import get_string_hash +from medperf.commands.certificate.utils 
import current_user_certificate_status import base64 @@ -16,7 +17,13 @@ def get_permitted_workloads(dataset: Dataset): user_obj = get_medperf_user_object() if dataset.owner != user_obj.id: raise MedperfException("User must be data owner") - user_cert = Certificate.get_user_certificate() + status_dict = current_user_certificate_status() + user_cert = None + if status_dict["should_be_submitted"]: + user_cert = Certificate.get_local_user_certificate() + elif status_dict["no_action_required"]: + user_cert = status_dict["user_cert_object"] + if not user_cert: raise MedperfException("User must have a certificate to update cc policy") public_key_bytes = user_cert.public_key() @@ -29,6 +36,11 @@ def get_permitted_workloads(dataset: Dataset): benchmark_id = assoc["benchmark"] benchmark = Benchmark.get(benchmark_id) evaluator = Cube.get(benchmark.data_evaluator_mlcube) + if evaluator.is_script(): + script_hash = evaluator.image_hash + else: + ref_model = Model.get(benchmark.reference_model) + script_hash = ref_model.container_obj.image_hash model_assocs = config.comms.get_benchmark_models_associations(benchmark_id) for model_assoc in model_assocs: model = Model.get(model_assoc["model"]) @@ -36,7 +48,7 @@ def get_permitted_workloads(dataset: Dataset): workload_info = CCWorkloadID( data_hash=dataset.generated_uid, model_hash=asset.asset_hash, - script_hash=evaluator.image_hash, + script_hash=script_hash, result_collector_hash=public_key_hash, data_id=dataset.id, model_id=model.id, diff --git a/cli/medperf/commands/cc/model_configure_for_cc.py b/cli/medperf/commands/cc/model_configure_for_cc.py index 916c01cf8..be950931f 100644 --- a/cli/medperf/commands/cc/model_configure_for_cc.py +++ b/cli/medperf/commands/cc/model_configure_for_cc.py @@ -1,23 +1,28 @@ from medperf.entities.model import Model -from medperf.asset_management.asset_management import setup_model_for_cc +from medperf.asset_management.asset_management import ( + setup_model_for_cc, + validate_cc_config, +) import 
json from medperf import config class ModelConfigureForCC: @classmethod - def run(cls, model_uid: int, cc_config_file: str, cc_policy_file: str): - model = Model.get(model_uid) + def run_from_files(cls, model_uid: int, cc_config_file: str, cc_policy_file: str): with open(cc_config_file) as f: cc_config = json.load(f) with open(cc_policy_file) as f: cc_policy = json.load(f) - model.set_cc_config(cc_config) - model.set_cc_policy(cc_policy) - body = {"user_metadata": model.user_metadata} - config.comms.update_model(model.id, body) - setup_model_for_cc(model) - # mark as set - model.mark_cc_configured() - body = {"user_metadata": model.user_metadata} - config.comms.update_model(model.id, body) + cls.run(model_uid, cc_config, cc_policy) + + @classmethod + def run(cls, model_uid: int, cc_config: dict, cc_policy: dict): + validate_cc_config(cc_config, "model" + str(model_uid)) + with config.ui.interactive(): + model = Model.get(model_uid) + model.set_cc_config(cc_config) + model.set_cc_policy(cc_policy) + setup_model_for_cc(model) + body = {"user_metadata": model.user_metadata} + config.comms.update_model(model.id, body) diff --git a/cli/medperf/commands/cc/model_update_cc_policy.py b/cli/medperf/commands/cc/model_update_cc_policy.py index f7b4c15f0..8c464c240 100644 --- a/cli/medperf/commands/cc/model_update_cc_policy.py +++ b/cli/medperf/commands/cc/model_update_cc_policy.py @@ -24,6 +24,11 @@ def get_permitted_workloads(model: Model): benchmark_id = assoc["benchmark"] benchmark = Benchmark.get(benchmark_id) evaluator = Cube.get(benchmark.data_evaluator_mlcube) + if evaluator.is_script(): + script_hash = evaluator.image_hash + else: + ref_model = Model.get(benchmark.reference_model) + script_hash = ref_model.container_obj.image_hash datasets_certs = config.comms.get_benchmark_datasets_certificates(benchmark_id) mappings = {} for cert in datasets_certs: @@ -44,7 +49,7 @@ def get_permitted_workloads(model: Model): workload_info = CCWorkloadID( 
data_hash=dataset.generated_uid, model_hash=asset.asset_hash, - script_hash=evaluator.image_hash, + script_hash=script_hash, result_collector_hash=public_key_hash, data_id=dataset.id, model_id=model.id, @@ -55,6 +60,37 @@ def get_permitted_workloads(model: Model): return permitted_workloads +def get_permitted_workloads_without_datasets(model: Model): + user_obj = get_medperf_user_object() + if model.owner != user_obj.id: + raise MedperfException("User must be model owner") + asset = model.asset_obj + + permitted_workloads = [] + assocs = config.comms.get_model_benchmarks_associations(model.id) + for assoc in assocs: + benchmark_id = assoc["benchmark"] + benchmark = Benchmark.get(benchmark_id) + evaluator = Cube.get(benchmark.data_evaluator_mlcube) + if evaluator.is_script(): + script_hash = evaluator.image_hash + else: + ref_model = Model.get(benchmark.reference_model) + script_hash = ref_model.container_obj.image_hash + workload_info = CCWorkloadID( + data_hash="", + model_hash=asset.asset_hash, + script_hash=script_hash, + result_collector_hash="", + data_id=1, + model_id=model.id, + script_id=evaluator.id, + ) + permitted_workloads.append(workload_info) + + return permitted_workloads + + class ModelUpdateCCPolicy: @classmethod def run(cls, model_uid: int): @@ -63,5 +99,5 @@ def run(cls, model_uid: int): raise MedperfException( f"Model {model.id} is not configured for confidential computing." 
) - permitted_workloads = get_permitted_workloads(model) + permitted_workloads = get_permitted_workloads_without_datasets(model) update_model_cc_policy(model, permitted_workloads) diff --git a/cli/medperf/commands/cc/setup_cc_operator.py b/cli/medperf/commands/cc/setup_cc_operator.py index 259d2360a..a69a29aae 100644 --- a/cli/medperf/commands/cc/setup_cc_operator.py +++ b/cli/medperf/commands/cc/setup_cc_operator.py @@ -1,22 +1,25 @@ import json -from medperf.asset_management.asset_management import setup_operator +from medperf.asset_management.asset_management import ( + setup_operator, + validate_cc_operator_config, +) from medperf.account_management import get_medperf_user_object from medperf import config class SetupCCOperator: @classmethod - def run(cls, cc_config_file: str): - user = get_medperf_user_object() + def run_from_files(cls, cc_config_file: str): with open(cc_config_file) as f: cc_config = json.load(f) + cls.run(cc_config) - user.set_cc_config(cc_config) - body = {"metadata": user.metadata} - config.comms.update_user(user.id, body) - setup_operator(user) - - # mark as set - user.mark_cc_configured() - body = {"metadata": user.metadata} - config.comms.update_user(user.id, body) + @classmethod + def run(cls, cc_config: dict): + validate_cc_operator_config(cc_config) + with config.ui.interactive(): + user = get_medperf_user_object() + user.set_cc_config(cc_config) + setup_operator(user) + body = {"metadata": user.metadata} + config.comms.update_user(user.id, body) diff --git a/cli/medperf/commands/dataset/check.py b/cli/medperf/commands/dataset/check.py new file mode 100644 index 000000000..4de3c106c --- /dev/null +++ b/cli/medperf/commands/dataset/check.py @@ -0,0 +1,19 @@ +from medperf.entities.dataset import Dataset +from medperf.account_management.account_management import get_medperf_user_data +from medperf.exceptions import InvalidArgumentError +from medperf import config + + +class DataCheck: + @staticmethod + def run(dataset_uid: int): + 
dataset = Dataset.get(dataset_uid) + user_id = get_medperf_user_data()["id"] + if dataset.owner != user_id: + raise InvalidArgumentError("Only the dataset owner can check the hash.") + if dataset.check_hash(): + config.ui.print("✅ Data hash matches the one registered on the server.") + else: + config.ui.print( + "❌ Data hash does not match the one registered on the server." + ) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 02bb9898e..859f0db4a 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -13,7 +13,7 @@ from medperf.commands.dataset.train import TrainingExecution from medperf.commands.dataset.import_dataset import ImportDataset from medperf.commands.dataset.export_dataset import ExportDataset - +from medperf.commands.dataset.check import DataCheck app = typer.Typer() @@ -33,8 +33,12 @@ def list( ), name: str = typer.Option(None, "--name", help="Filter by name"), owner: int = typer.Option(None, "--owner", help="Filter by owner"), - state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"), - is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"), + state: str = typer.Option( + None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)" + ), + is_valid: bool = typer.Option( + None, "--valid/--invalid", help="Filter by valid status" + ), ): """List datasets""" EntityList.run( @@ -125,6 +129,17 @@ def prepare( ui.print("✅ Done!") +@app.command("check") +@clean_except +def check( + data_uid: str = typer.Option(..., "--data_uid", "-d", help="Dataset UID"), +): + """Checks if the hash of the dataset matches the one registered the server""" + ui = config.ui + DataCheck.run(data_uid) + ui.print("✅ Done!") + + @app.command("set_operational") @clean_except def set_operational( diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 
2e506ddf4..b25624ac1 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -1,6 +1,6 @@ from medperf.entities.dataset import Dataset import medperf.config as config -from medperf.utils import approval_prompt, dict_pretty_print, get_folders_hash +from medperf.utils import approval_prompt, dict_pretty_print from medperf.exceptions import CleanExit, InvalidArgumentError import yaml @@ -32,12 +32,8 @@ def validate(self): def generate_uids(self): """Auto-generates dataset UIDs for both input and output paths""" - raw_data_path, raw_labels_path = self.dataset.get_raw_paths() - prepared_data_path = self.dataset.data_path - prepared_labels_path = self.dataset.labels_path - - in_uid = get_folders_hash([raw_data_path, raw_labels_path]) - generated_uid = get_folders_hash([prepared_data_path, prepared_labels_path]) + in_uid = self.dataset.calculate_raw_hash() + generated_uid = self.dataset.calculate_prepared_hash() self.dataset.input_data_hash = in_uid self.dataset.generated_uid = generated_uid diff --git a/cli/medperf/commands/execution/confidential_execution.py b/cli/medperf/commands/execution/confidential_execution.py index 08d284dcd..44a66ce5a 100644 --- a/cli/medperf/commands/execution/confidential_execution.py +++ b/cli/medperf/commands/execution/confidential_execution.py @@ -18,6 +18,7 @@ from medperf.asset_management.asset_management import run_workload, download_results from medperf.utils import get_string_hash from medperf.commands.certificate.utils import load_user_private_key +from medperf.containers.runners.docker_utils import full_docker_image_name class ConfidentialExecution: @@ -132,8 +133,7 @@ def setup_workload(self): def run_workload(self): config.ui.text = "Running CC workload..." 
docker_image = self.script.parser.get_setup_args() - # TODO: docker.io/ - docker_image = "docker.io/" + docker_image + docker_image = full_docker_image_name(docker_image) run_workload( docker_image, self.workload, diff --git a/cli/medperf/commands/execution/confidential_model_container_execution.py b/cli/medperf/commands/execution/confidential_model_container_execution.py new file mode 100644 index 000000000..b043de176 --- /dev/null +++ b/cli/medperf/commands/execution/confidential_model_container_execution.py @@ -0,0 +1,209 @@ +import base64 + +from medperf.asset_management.gcp_utils import CCWorkloadID +from medperf.entities.cube import Cube +from medperf.entities.model import Model +from medperf.entities.dataset import Dataset +from medperf.entities.execution import Execution +from medperf.entities.certificate import Certificate +import medperf.config as config +from medperf.exceptions import DecryptionError, ExecutionError + +from medperf.account_management import get_medperf_user_object +from medperf.asset_management.asset_management import ( + run_workload, + download_results, + workload_results_exists, + wait_for_workload, +) +from medperf.utils import get_string_hash +from medperf.commands.certificate.utils import ( + current_user_certificate_status, + load_user_private_key, +) +from medperf.commands.execution.container_execution import ContainerExecution +from medperf.containers.runners.docker_utils import full_docker_image_name + + +class ConfidentialModelContainerExecution: + @classmethod + def run( + cls, + benchmark_id: int, + dataset: Dataset, + model: Model, + script: Cube, + evaluator: Cube, + execution: Execution = None, + ignore_model_errors=False, + ): + """Benchmark execution flow. 
+ + Args: + benchmark_uid (int): UID of the desired benchmark + data_uid (str): Registered Dataset UID + model_uid (int): UID of model to execute + """ + execution_flow = cls( + benchmark_id, + dataset, + model, + script, + evaluator, + execution, + ignore_model_errors, + ) + execution_flow.setup_local_environment() + with config.ui.interactive(): + execution_flow.get_operator() + execution_flow.validate() + execution_flow.prepare() + execution_flow.setup_workload() + if not execution_flow.results_exist(): + execution_flow.run_workload() + execution_flow.wait_for_workload_completion() + execution_flow.download_predictions() + execution_flow.run_evaluation() + execution_summary = execution_flow.todict() + return execution_summary + + def __init__( + self, + benchmark_id: int, + dataset: Dataset, + model: Model, + script: Cube, + evaluator: Cube, + execution: Execution = None, + ignore_model_errors=False, + ): + self.comms = config.comms + self.ui = config.ui + self.benchmark_id = benchmark_id + self.dataset = dataset + self.model = model + self.script = script + self.evaluator = evaluator + self.execution = execution + self.ignore_model_errors = ignore_model_errors + self.operator = None + self.dataset_cc_config = None + self.model_cc_config = None + self.operator_cc_config = None + self.local_execution_flow = None + + def setup_local_environment(self): + self.local_execution_flow = ContainerExecution( + self.dataset, + self.model, + self.evaluator, + self.execution, + self.ignore_model_errors, + ) + self.local_execution_flow.prepare() + + def get_operator(self): + self.operator = get_medperf_user_object() + + def validate(self): + if not self.dataset.is_cc_configured(): + raise ExecutionError( + f"Dataset {self.dataset.id} is not configured for confidential computing." + ) + if not self.model.is_cc_configured(): + raise ExecutionError( + f"Model {self.model.id} is not configured for confidential computing." 
+ ) + if not self.operator.is_cc_configured(): + raise ExecutionError( + "User does not have a configuration to operate a confidential execution." + ) + + def prepare(self): + self.dataset_cc_config = self.dataset.get_cc_config() + self.model_cc_config = self.model.get_cc_config() + self.operator_cc_config = self.operator.get_cc_config() + self.asset = self.model.asset_obj + + def setup_workload(self): + if self.dataset.owner == self.operator.id: + status_dict = current_user_certificate_status() + user_cert = None + if status_dict["should_be_submitted"]: + user_cert = Certificate.get_local_user_certificate() + elif status_dict["no_action_required"]: + user_cert = status_dict["user_cert_object"] + + if not user_cert: + raise ExecutionError( + "User must have a certificate to run the confidential model" + ) + cert_obj = user_cert + else: + datasets_certs = config.comms.get_benchmark_datasets_certificates( + self.benchmark_id + ) + for cert in datasets_certs: + if cert["owner"]["id"] == self.dataset.owner: + cert.pop("owner") + cert_obj = Certificate(**cert) + break + else: + raise ExecutionError( + "Dataset not associated. Can't find data owner certificate." 
+ ) + + public_key_bytes = cert_obj.public_key() + result_collector_public_key = base64.b64encode(public_key_bytes) + workload = CCWorkloadID( + data_hash=self.dataset.generated_uid, + model_hash=self.asset.asset_hash, + script_hash=self.script.image_hash, + result_collector_hash=get_string_hash(result_collector_public_key), + data_id=self.dataset.id, + model_id=self.asset.id, + script_id=self.script.id, + execution_id=self.execution.id, + ) + + self.workload = workload + self.result_collector_public_key = result_collector_public_key + + def results_exist(self): + return workload_results_exists(self.operator_cc_config, self.workload) + + def run_workload(self): + config.ui.text = "Starting Confidential VM" + docker_image = self.script.parser.get_setup_args() + docker_image = full_docker_image_name(docker_image) + run_workload( + docker_image, + self.workload, + self.dataset_cc_config, + self.model_cc_config, + self.operator_cc_config, + self.result_collector_public_key.decode("utf-8"), + ) + + def wait_for_workload_completion(self): + config.ui.text = "Waiting for workload completion" + wait_for_workload(self.workload, self.operator_cc_config) + if not self.results_exist(): + raise ExecutionError("Workload did not complete successfully.") + + def download_predictions(self): + config.ui.text = "Downloading inference predictions" + results_path = self.local_execution_flow.preds_path + private_key_bytes = load_user_private_key() + if private_key_bytes is None: + raise DecryptionError("Missing Private Key") + + download_results( + self.operator_cc_config, self.workload, private_key_bytes, results_path + ) + + def run_evaluation(self): + return self.local_execution_flow.run_evaluation() + + def todict(self): + return self.local_execution_flow.todict() diff --git a/cli/medperf/commands/execution/container_execution.py b/cli/medperf/commands/execution/container_execution.py index 8088a2ac7..2bdf0f2ed 100644 --- a/cli/medperf/commands/execution/container_execution.py +++ 
b/cli/medperf/commands/execution/container_execution.py @@ -102,8 +102,8 @@ def __setup_local_outputs_path(self): return local_outputs_path def set_pending_status(self): - self.__send_model_report("pending") - self.__send_evaluator_report("pending") + self.send_model_report("pending") + self.send_evaluator_report("pending") def run_inference(self): self.ui.text = f"Running inference of model '{self.model.name}' on dataset" @@ -112,7 +112,7 @@ def run_inference(self): "data_path": self.dataset.data_path, "output_path": self.preds_path, } - self.__send_model_report("started") + self.send_model_report("started") try: self.model.run( task="infer", @@ -123,7 +123,7 @@ def run_inference(self): self.ui.print("> Model execution complete") except ExecutionError as e: - self.__send_model_report("failed") + self.send_model_report("failed") if not self.ignore_model_errors: logging.error(f"Model Execution failed: {e}") raise ExecutionError(f"Model Execution failed: {e}") @@ -133,12 +133,12 @@ def run_inference(self): return except KeyboardInterrupt: logging.warning("Model Execution interrupted by user") - self.__send_model_report("interrupted") + self.send_model_report("interrupted") raise CleanExit("Model Execution interrupted by user") - self.__send_model_report("finished") + self.send_model_report("finished") def run_evaluation(self): - self.ui.text = f"Calculating metrics for model '{self.model.name}' predictions" + self.ui.text = "Calculating metrics" evaluate_timeout = config.evaluate_timeout evaluator_mounts = { "predictions": self.preds_path, @@ -146,7 +146,7 @@ def run_evaluation(self): "output_path": self.results_path, "local_outputs_path": self.local_outputs_path, } - self.__send_evaluator_report("started") + self.send_evaluator_report("started") try: self.evaluator.run( task="evaluate", @@ -156,13 +156,13 @@ def run_evaluation(self): ) except ExecutionError as e: logging.error(f"Metrics calculation failed: {e}") - self.__send_evaluator_report("failed") + 
self.send_evaluator_report("failed") raise ExecutionError(f"Metrics calculation failed: {e}") except KeyboardInterrupt: logging.warning("Metrics calculation interrupted by user") - self.__send_evaluator_report("interrupted") + self.send_evaluator_report("interrupted") raise CleanExit("Metrics calculation interrupted by user") - self.__send_evaluator_report("finished") + self.send_evaluator_report("finished") def todict(self): return { @@ -179,10 +179,10 @@ def get_results(self): raise ExecutionError("Results file is empty") return results - def __send_model_report(self, status: str): + def send_model_report(self, status: str): self.__send_report("model_report", status) - def __send_evaluator_report(self, status: str): + def send_evaluator_report(self, status: str): self.__send_report("evaluation_report", status) def __send_report(self, field: str, status: str): diff --git a/cli/medperf/commands/execution/create.py b/cli/medperf/commands/execution/create.py index dbaf47b40..05e725056 100644 --- a/cli/medperf/commands/execution/create.py +++ b/cli/medperf/commands/execution/create.py @@ -237,7 +237,7 @@ def run_experiments(self) -> list[Execution]: "cached": False, "error": str(e), "partial": "N/A", - "confidential": model.is_cc_mode(), + "confidential": model.requires_cc(), } ) continue @@ -254,7 +254,7 @@ def run_experiments(self) -> list[Execution]: "cached": False, "error": "", "partial": execution_summary["partial"], - "confidential": model.is_cc_mode(), + "confidential": model.requires_cc(), } ) return [experiment["execution"] for experiment in self.experiments] diff --git a/cli/medperf/commands/execution/execution_flow.py b/cli/medperf/commands/execution/execution_flow.py index 35120a79f..3a63a2a9a 100644 --- a/cli/medperf/commands/execution/execution_flow.py +++ b/cli/medperf/commands/execution/execution_flow.py @@ -2,11 +2,16 @@ from medperf.entities.model import Model from medperf.entities.dataset import Dataset from medperf.entities.execution import 
Execution +from medperf.entities.benchmark import Benchmark from medperf.enums import ModelType from medperf.commands.execution.container_execution import ContainerExecution from medperf.commands.execution.script_execution import ScriptExecution from medperf.commands.execution.confidential_execution import ConfidentialExecution +from medperf.commands.execution.confidential_model_container_execution import ( + ConfidentialModelContainerExecution, +) from medperf.account_management import get_medperf_user_data, is_user_logged_in +from medperf.exceptions import ExecutionError class ExecutionFlow: @@ -25,14 +30,36 @@ def run( ) if ( - model.type == ModelType.ASSET.value - and model.is_cc_mode() + evaluator.is_script() + and model.type == ModelType.ASSET.value + and model.requires_cc() and not user_is_model_owner ): return ConfidentialExecution.run( benchmark_id, dataset, model, evaluator, execution, ignore_model_errors ) + elif ( + model.type == ModelType.ASSET.value + and model.requires_cc() + and not user_is_model_owner + ): + benchmark = Benchmark.get(benchmark_id) + ref_model = Model.get(benchmark.reference_model) + script = ref_model.container_obj + return ConfidentialModelContainerExecution.run( + benchmark_id, + dataset, + model, + script, + evaluator, + execution, + ignore_model_errors, + ) elif model.type == ModelType.ASSET.value: + if not evaluator.is_script(): + raise ExecutionError( + "Running a model container with another asset model is not supported yet." 
+ ) asset = model.asset_obj asset.prepare_asset_files() return ScriptExecution.run( diff --git a/cli/medperf/containers/parsers/parser.py b/cli/medperf/containers/parsers/parser.py index 87bc8d097..98c5c1af2 100644 --- a/cli/medperf/containers/parsers/parser.py +++ b/cli/medperf/containers/parsers/parser.py @@ -49,3 +49,7 @@ def is_docker_image(self): @abstractmethod def is_model_container(self): pass + + @abstractmethod + def is_script_container(self): + pass diff --git a/cli/medperf/containers/parsers/simple_container.py b/cli/medperf/containers/parsers/simple_container.py index 02ec653f6..1404833b7 100644 --- a/cli/medperf/containers/parsers/simple_container.py +++ b/cli/medperf/containers/parsers/simple_container.py @@ -118,3 +118,6 @@ def is_docker_image(self): def is_model_container(self): return "infer" in self.container_config["tasks"] + + def is_script_container(self): + return "run_script" in self.container_config["tasks"] diff --git a/cli/medperf/containers/runners/docker_utils.py b/cli/medperf/containers/runners/docker_utils.py index 38367e217..856e5cf06 100644 --- a/cli/medperf/containers/runners/docker_utils.py +++ b/cli/medperf/containers/runners/docker_utils.py @@ -8,6 +8,7 @@ import json import tarfile import logging +from docker.auth import resolve_repository_name def get_docker_image_hash(docker_image, timeout: int = None): @@ -177,3 +178,15 @@ def delete_images(images): run_command(delete_image_cmd) except ExecutionError: config.ui.print_warning("WARNING: Failed to delete docker images.") + + +def full_docker_image_name(image_name: str) -> str: + """ + Returns the full docker image name with registry. + If the image name does not contain a registry, it is assumed to be docker.io. 
+ """ + logging.debug(f"Resolving full docker image name for {image_name}") + registry, name = resolve_repository_name(image_name) + resolved_name = f"{registry}/{name}" + logging.debug(f"Resolved docker image name: {resolved_name}") + return resolved_name diff --git a/cli/medperf/entities/certificate.py b/cli/medperf/entities/certificate.py index ded8212f0..8236e93d7 100644 --- a/cli/medperf/entities/certificate.py +++ b/cli/medperf/entities/certificate.py @@ -5,7 +5,7 @@ from medperf.account_management import get_medperf_user_data from medperf import config from medperf.exceptions import MedperfException -from medperf.utils import generate_tmp_path +from medperf.utils import generate_tmp_path, get_pki_assets_path import base64 from typing import List, Tuple import logging @@ -95,6 +95,26 @@ def get_user_certificate(cls): ) return user_certificates[0] + @classmethod + def get_local_user_certificate(cls): + email = get_medperf_user_data()["email"] + local_cert_folder = get_pki_assets_path(email, config.certificate_authority_id) + local_certificate_file = os.path.join( + local_cert_folder, config.certificate_file + ) + if not os.path.exists(local_certificate_file): + logging.debug(f"No local certificate found: {local_certificate_file}") + return + with open(local_certificate_file, "rb") as f: + local_certificate_content = f.read() + + cert_b64encoded = base64.b64encode(local_certificate_content).decode("utf-8") + return cls( + name="tmp_local_cert", + certificate_content_base64=cert_b64encoded, + ca=config.certificate_authority_id, + ) + @classmethod def remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index c776c0fa1..b66183192 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -101,6 +101,9 @@ def is_encrypted(self) -> bool: def is_model(self) -> bool: return 
self.parser.is_model_container() + def is_script(self) -> bool: + return self.parser.is_script_container() + @staticmethod def remote_prefilter(filters: dict): """Applies filtering logic that must be done before retrieving remote entities diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 27a56287e..b7d61bbf9 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -3,13 +3,14 @@ import yaml from typing import List -from medperf.utils import remove_path +from medperf.utils import get_folders_hash, remove_path from medperf.entities.interface import Entity from medperf.entities.schemas import DatasetSchema import medperf.config as config from medperf.account_management import get_medperf_user_data -from medperf.exceptions import MedperfException from medperf.entities.utils import handle_validation_error +from medperf.exceptions import InvalidEntityError +import logging class Dataset(Entity): @@ -73,7 +74,7 @@ def local_id(self): def get_cc_config(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.user_metadata: @@ -82,24 +83,18 @@ def set_cc_config(self, cc_config: dict): def get_cc_policy(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("policy", None) + return cc_values.get("policy", {}) def set_cc_policy(self, cc_policy: dict): if "cc" not in self.user_metadata: self.user_metadata["cc"] = {} self.user_metadata["cc"]["policy"] = cc_policy - def mark_cc_configured(self): - if "cc" not in self.user_metadata: - raise MedperfException( - "Dataset does not have a cc configuration to be marked as configured" - ) - self.user_metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.user_metadata: - return False - return self.user_metadata["cc"].get("configured", False) + return self.get_cc_config() != {} + + def 
is_operational(self): + return self.state == "OPERATION" def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) @@ -126,6 +121,23 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) + def calculate_raw_hash(self): + raw_data_path, raw_labels_path = self.get_raw_paths() + calculated_hash = get_folders_hash([raw_data_path, raw_labels_path]) + logging.debug(f"Raw dataset calculated hash: {calculated_hash}") + return calculated_hash + + def calculate_prepared_hash(self): + calculated_hash = get_folders_hash([self.data_path, self.labels_path]) + logging.debug(f"Prepared dataset calculated hash: {calculated_hash}") + return calculated_hash + + def check_hash(self): + if not self.is_operational(): + raise InvalidEntityError("Dataset is not operational. Cannot check hash.") + calculated_hash = self.calculate_prepared_hash() + return calculated_hash == self.generated_uid + @staticmethod def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities diff --git a/cli/medperf/entities/model.py b/cli/medperf/entities/model.py index d4ddcadbc..f92f0a4e5 100644 --- a/cli/medperf/entities/model.py +++ b/cli/medperf/entities/model.py @@ -76,12 +76,13 @@ def asset_obj(self): def is_encrypted(self) -> bool: return self.is_container() and self.container_obj.is_encrypted() - def is_cc_mode(self): - return "cc" in self.user_metadata + def requires_cc(self): + # for now, let's do this + return self.is_asset() and self.asset_obj.is_local() def get_cc_config(self): cc_values = self.user_metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.user_metadata: @@ -90,24 +91,15 @@ def set_cc_config(self, cc_config: dict): def get_cc_policy(self): cc_values = 
self.user_metadata.get("cc", {}) - return cc_values.get("policy", None) + return cc_values.get("policy", {}) def set_cc_policy(self, cc_policy: dict): if "cc" not in self.user_metadata: self.user_metadata["cc"] = {} self.user_metadata["cc"]["policy"] = cc_policy - def mark_cc_configured(self): - if "cc" not in self.user_metadata: - raise MedperfException( - "Model does not have a cc configuration to be marked as configured" - ) - self.user_metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.user_metadata: - return False - return self.user_metadata["cc"].get("configured", False) + return self.get_cc_config() != {} @staticmethod def remote_prefilter(filters: dict) -> callable: diff --git a/cli/medperf/entities/user.py b/cli/medperf/entities/user.py index 49ce1fd97..d7b836907 100644 --- a/cli/medperf/entities/user.py +++ b/cli/medperf/entities/user.py @@ -1,5 +1,4 @@ from medperf.entities.schemas import UserSchema -from medperf.exceptions import MedperfException from medperf.entities.utils import handle_validation_error @@ -31,21 +30,12 @@ def __setattr__(self, name, value): def get_cc_config(self): cc_values = self.metadata.get("cc", {}) - return cc_values.get("config", None) + return cc_values.get("config", {}) def set_cc_config(self, cc_config: dict): if "cc" not in self.metadata: self.metadata["cc"] = {} self.metadata["cc"]["config"] = cc_config - def mark_cc_configured(self): - if "cc" not in self.metadata: - raise MedperfException( - "User does not have a cc configuration to be marked as configured" - ) - self.metadata["cc"]["configured"] = True - def is_cc_configured(self): - if "cc" not in self.metadata: - return False - return self.metadata["cc"].get("configured", False) + return self.get_cc_config() != {} diff --git a/cli/medperf/tests/commands/dataset/test_set_operational.py b/cli/medperf/tests/commands/dataset/test_set_operational.py index 38adee920..fe40f237a 100644 --- 
a/cli/medperf/tests/commands/dataset/test_set_operational.py +++ b/cli/medperf/tests/commands/dataset/test_set_operational.py @@ -41,19 +41,19 @@ def test_validate_fails_if_dataset_is_not_marked_as_ready(mocker, set_operationa def test_generate_uids_assigns_uids_to_obj_properties(mocker, set_operational): # Arrange - in_path = ["/usr/data/path", "usr/labels/path"] - out_path = ["~/.medperf/data/123/data", "~/.medperf/data/123/labels"] - mocker.patch(PATCH_OPERATIONAL.format("get_folders_hash"), side_effect=lambda x: x) - mocker.patch.object(set_operational.dataset, "get_raw_paths", return_value=in_path) - set_operational.dataset.data_path = out_path[0] - set_operational.dataset.labels_path = out_path[1] + mocker.patch.object( + set_operational.dataset, "calculate_raw_hash", return_value="in_hash" + ) + mocker.patch.object( + set_operational.dataset, "calculate_prepared_hash", return_value="out_hash" + ) # Act set_operational.generate_uids() # Assert - assert set_operational.dataset.input_data_hash == in_path - assert set_operational.dataset.generated_uid == out_path + assert set_operational.dataset.input_data_hash == "in_hash" + assert set_operational.dataset.generated_uid == "out_hash" def test_statistics_are_updated(mocker, set_operational, fs): diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index d0aff3ad6..4c96b5910 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -658,6 +658,11 @@ def tmp_path_for_key_decryption(): return _tmp_path_for_decryption(base_path=config.container_keys_dir) +def tmp_path_for_cc_asset_key(): + """Generates a temporary file path to write key for decryption""" + return _tmp_path_for_decryption(base_path=config.cc_artifacts_dir) + + def secure_write_to_file(file_path, content: bytes, exec_permission=False): permission_mode = 0o700 if exec_permission else 0o600 with open(file_path, "wb") as f: diff --git a/cli/medperf/web_ui/benchmarks/routes.py b/cli/medperf/web_ui/benchmarks/routes.py index c913a9ca5..2f84e6431 
100644 --- a/cli/medperf/web_ui/benchmarks/routes.py +++ b/cli/medperf/web_ui/benchmarks/routes.py @@ -177,11 +177,12 @@ def register_benchmark( request: Request, name: str = Form(...), description: str = Form(...), - reference_dataset_tarball_url: str = Form(...), + reference_dataset_tarball_url: str = Form(""), data_preparation_container: str = Form(...), reference_model: str = Form(...), evaluator_container: str = Form(...), skip_data_preparation_step: bool = Form(...), + skip_compatibility_tests: bool = Form(...), current_user: bool = Depends(check_user_api), ): @@ -201,7 +202,9 @@ def register_benchmark( benchmark_id = None try: benchmark_id = SubmitBenchmark.run( - benchmark_info, skip_data_preparation_step=skip_data_preparation_step + benchmark_info, + skip_data_preparation_step=skip_data_preparation_step, + skip_compatibility_tests=skip_compatibility_tests, ) return_response["status"] = "success" return_response["benchmark_id"] = benchmark_id diff --git a/cli/medperf/web_ui/datasets/routes.py b/cli/medperf/web_ui/datasets/routes.py index 45696ee00..17abfcb82 100644 --- a/cli/medperf/web_ui/datasets/routes.py +++ b/cli/medperf/web_ui/datasets/routes.py @@ -6,7 +6,7 @@ from fastapi import Request, APIRouter, Depends, Form from medperf import config -from medperf.account_management import get_medperf_user_data +from medperf.account_management import get_medperf_user_data, get_medperf_user_object from medperf.commands.mlcube.utils import check_access_to_container from medperf.commands.dataset.associate import AssociateDataset from medperf.commands.dataset.export_dataset import ExportDataset @@ -17,6 +17,8 @@ from medperf.commands.execution.create import BenchmarkExecution from medperf.commands.execution.submit import ResultSubmission from medperf.commands.execution.utils import filter_latest_executions +from medperf.commands.cc.dataset_configure_for_cc import DatasetConfigureForCC +from medperf.commands.cc.dataset_update_cc_policy import DatasetUpdateCCPolicy 
from medperf.entities.cube import Cube from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark @@ -87,22 +89,24 @@ def dataset_detail_ui( # noqa ref_model_id = valid_benchmarks[benchmark].reference_model valid_benchmarks[benchmark].reference_model = Model.get(ref_model_id) - dataset_is_operational = dataset.state == "OPERATION" - dataset_is_prepared = ( - dataset.submitted_as_prepared or dataset.is_ready() or dataset_is_operational - ) + dataset_is_operational = dataset.is_operational() + dataset_is_prepared = dataset.is_ready() or dataset_is_operational approved_benchmarks = [ i for i in benchmark_associations if benchmark_associations[i]["approval_status"] == "APPROVED" ] - my_user_id = get_medperf_user_data()["id"] + user_obj = get_medperf_user_object() + my_user_id = user_obj.id is_owner = my_user_id == dataset.owner + dataset_hash_mismatch = None + if dataset_is_operational and is_owner: + dataset_hash_mismatch = not dataset.check_hash() # Get all results results = [] if benchmark_assocs: - user_id = get_medperf_user_data()["id"] + user_id = user_obj.id results = Execution.all(filters={"owner": user_id}) results = filter_latest_executions(results) @@ -116,8 +120,25 @@ def dataset_detail_ui( # noqa benchmark_models[assoc["benchmark"]] = models for model in models + [valid_benchmarks[assoc["benchmark"]].reference_model]: model._encrypted = model.is_encrypted() + model._requires_cc = model.requires_cc() if model._encrypted: model.access_status = check_access_to_container(model.container.id) + if model._requires_cc: + if not dataset.is_cc_configured(): + reason = "Your dataset is not configured for CC yet" + can_run = False + elif not model.is_cc_configured(): + reason = "Wait for model owner to configure their CC settings" + can_run = False + elif not user_obj.is_cc_configured(): + reason = ( + "You haven't configured your workload run settings for CC yet" + ) + can_run = False + else: + reason = "" + can_run = True + 
model.cc_run_status = {"can_run": can_run, "reason": reason} model.result = None for result in results: if ( @@ -133,6 +154,10 @@ def dataset_detail_ui( # noqa model.result["results"] = result.read_results() report_exists = os.path.exists(dataset.report_path) + + cc_config_defaults = dataset.get_cc_config() + cc_configured = dataset.is_cc_configured() + return templates.TemplateResponse( "dataset/dataset_detail.html", { @@ -141,12 +166,15 @@ def dataset_detail_ui( # noqa "prep_cube": prep_cube, "dataset_is_prepared": dataset_is_prepared, "dataset_is_operational": dataset_is_operational, + "dataset_hash_mismatch": dataset_hash_mismatch, "benchmark_associations": benchmark_associations, # "benchmarks": valid_benchmarks, # Benchmarks that can be associated "benchmark_models": benchmark_models, # Pass associated models without status "approved_benchmarks": approved_benchmarks, "is_owner": is_owner, "report_exists": report_exists, + "cc_config_defaults": cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -167,6 +195,7 @@ def create_dataset_ui( @router.post("/register/", response_class=JSONResponse) def register_dataset( request: Request, + submit_as_prepared: bool = Form(False), benchmark: int = Form(...), name: str = Form(...), description: str = Form(...), @@ -189,7 +218,7 @@ def register_dataset( description=description, location=location, approved=False, - submit_as_prepared=False, + submit_as_prepared=bool(submit_as_prepared), ) return_response["status"] = "success" return_response["dataset_id"] = dataset_id @@ -382,7 +411,7 @@ def export_dataset_ui( dataset.read_statistics() prep_cube = Cube.get(cube_uid=dataset.data_preparation_mlcube) dataset_is_operational = dataset.state == "OPERATION" - dataset_is_prepared = ( + dataset_is_prepared = ( # TODO: should we use submitted_as_prepared here? 
dataset.submitted_as_prepared or dataset.is_ready() or dataset_is_operational ) report_exists = os.path.exists(dataset.report_path) @@ -475,3 +504,65 @@ def import_dataset( ) return return_response + + +@router.post("/edit_cc_config", response_class=JSONResponse) +def edit_cc_config( + request: Request, + entity_id: int = Form(...), + require_cc: bool = Form(False), + project_id: str = Form(""), + project_number: str = Form(""), + bucket: str = Form(""), + keyring_name: str = Form(""), + key_name: str = Form(""), + key_location: str = Form(""), + wip: str = Form(""), + wip_provider: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "project_number": project_number, + "bucket": bucket, + "keyring_name": keyring_name, + "key_name": key_name, + "key_location": key_location, + "wip": wip, + "wip_provider": wip_provider, + } + if not require_cc: + args = {} + initialize_state_task(request, task_name="data_update_cc_config") + return_response = {"status": "", "error": ""} + try: + DatasetConfigureForCC.run(entity_id, args, {}) + return_response["status"] = "success" + notification_message = "Successfully updated dataset CC config!" 
+ except Exception as exp: + return_response["status"] = "failed" + return_response["error"] = str(exp) + notification_message = "Failed to update dataset CC config" + logger.exception(exp) + + config.ui.end_task(return_response) + reset_state_task(request) + config.ui.add_notification( + message=notification_message, + return_response=return_response, + url=f"/datasets/ui/display/{entity_id}", + ) + return return_response + + +@router.post("/sync_cc_policy", response_class=JSONResponse) +def sync_cc_policy( + entity_id: int = Form(...), + current_user: bool = Depends(check_user_api), +): + try: + DatasetUpdateCCPolicy.run(entity_id) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/models/routes.py b/cli/medperf/web_ui/models/routes.py index 74e1eb5c9..65f827773 100644 --- a/cli/medperf/web_ui/models/routes.py +++ b/cli/medperf/web_ui/models/routes.py @@ -8,6 +8,8 @@ from medperf.entities.benchmark import Benchmark from medperf.commands.model.associate import AssociateModel from medperf.commands.mlcube.utils import check_access_to_container +from medperf.commands.cc.model_configure_for_cc import ModelConfigureForCC +from medperf.commands.cc.model_update_cc_policy import ModelUpdateCCPolicy import medperf.config as config from medperf.web_ui.common import ( check_user_api, @@ -74,6 +76,8 @@ def model_detail_ui( else: container_object = model.container_obj + cc_config_defaults = model.get_cc_config() + cc_configured = model.is_cc_configured() return templates.TemplateResponse( "model/model_detail.html", { @@ -86,6 +90,8 @@ def model_detail_ui( "is_owner": is_owner, "benchmarks_associations": benchmark_associations, # "benchmarks": benchmarks, + "cc_config_defaults": cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -117,3 +123,66 @@ def associate( url=f"/models/ui/display/{model_id}", ) return return_response + + 
+@router.post("/edit_cc_config", response_class=JSONResponse) +def edit_cc_config( + request: Request, + entity_id: int = Form(...), + require_cc: bool = Form(False), + project_id: str = Form(""), + project_number: str = Form(""), + bucket: str = Form(""), + keyring_name: str = Form(""), + key_name: str = Form(""), + key_location: str = Form(""), + wip: str = Form(""), + wip_provider: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "project_number": project_number, + "bucket": bucket, + "keyring_name": keyring_name, + "key_name": key_name, + "key_location": key_location, + "wip": wip, + "wip_provider": wip_provider, + } + if not require_cc: + args = {} + + initialize_state_task(request, task_name="model_update_cc_config") + return_response = {"status": "", "error": ""} + try: + ModelConfigureForCC.run(entity_id, args, {}) + return_response["status"] = "success" + notification_message = "Successfully updated model CC config!" + except Exception as exp: + return_response["status"] = "failed" + return_response["error"] = str(exp) + notification_message = "Failed to update model CC config" + logger.exception(exp) + + config.ui.end_task(return_response) + reset_state_task(request) + config.ui.add_notification( + message=notification_message, + return_response=return_response, + url=f"/models/ui/display/{entity_id}", + ) + return return_response + + +@router.post("/sync_cc_policy", response_class=JSONResponse) +def sync_cc_policy( + entity_id: int = Form(...), + current_user: bool = Depends(check_user_api), +): + try: + ModelUpdateCCPolicy.run(entity_id) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/settings.py b/cli/medperf/web_ui/settings.py index d188833ed..4717fd4a6 100644 --- a/cli/medperf/web_ui/settings.py +++ b/cli/medperf/web_ui/settings.py @@ -6,8 +6,10 @@ from 
medperf.commands.certificate.delete_client_certificate import DeleteCertificate from medperf.commands.certificate.submit import SubmitCertificate from medperf.commands.certificate.utils import current_user_certificate_status +from medperf.account_management.account_management import get_medperf_user_object from medperf.commands.utils import set_profile_args from medperf.config_management.config_management import read_config, write_config +from medperf.commands.cc.setup_cc_operator import SetupCCOperator from medperf.entities.ca import CA from medperf.exceptions import InvalidArgumentError from medperf.utils import make_pretty_dict @@ -33,11 +35,18 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)): cas = None certificate_status = None + cc_config_defaults = {} + cc_configured = False + if is_logged_in(): cas = CA.all() cas = {c.id: c.name for c in cas} certificate_status = current_user_certificate_status() + user = get_medperf_user_object() + cc_config_defaults = user.get_cc_config() + cc_configured = user.is_cc_configured() + return templates.TemplateResponse( "settings.html", { @@ -49,6 +58,8 @@ def settings_ui(request: Request, current_user: bool = Depends(check_user_ui)): "default_ca": config.certificate_authority_id, "default_fingerprint": config.certificate_authority_fingerprint, "certificate_status": certificate_status, + "cc_config_defaults": cc_config_defaults, + "cc_configured": cc_configured, }, ) @@ -206,3 +217,31 @@ def submit_certificate( return_response=return_response, ) return return_response + + +@router.post("/edit_cc_operator", response_class=JSONResponse) +def edit_cc_operator( + require_cc: bool = Form(...), + project_id: str = Form(""), + service_account_name: str = Form(""), + bucket: str = Form(""), + vm_zone: str = Form(""), + vm_name: str = Form(""), + current_user: bool = Depends(check_user_api), +): + args = { + "project_id": project_id, + "service_account_name": service_account_name, + "bucket": bucket, + 
"vm_zone": vm_zone, + "vm_name": vm_name, + } + if not require_cc: + args = {} + + try: + SetupCCOperator.run(args) + return {"status": "success", "error": ""} + except Exception as exp: + logger.exception(exp) + return {"status": "failed", "error": str(exp)} diff --git a/cli/medperf/web_ui/static/js/benchmarks/benchmark_register.js b/cli/medperf/web_ui/static/js/benchmarks/benchmark_register.js index 97a9512f5..9115135d0 100644 --- a/cli/medperf/web_ui/static/js/benchmarks/benchmark_register.js +++ b/cli/medperf/web_ui/static/js/benchmarks/benchmark_register.js @@ -37,6 +37,7 @@ function checkBenchmarkFormValidity() { const nameValue = $("#name").val().trim(); const descriptionValue = $("#description").val().trim(); const referenceDatasetTarballUrlValue = $("#reference-dataset-tarball-url").val().trim(); + const skipTestsValue = $("input[name='skip_compatibility_tests']:checked").val(); var dataPreparationContainerValue = $("#data-preparation-container").val(); var referenceModelValue = $("#reference-model").val(); @@ -49,7 +50,7 @@ function checkBenchmarkFormValidity() { const isValid = Boolean( nameValue.length > 0 && descriptionValue.length > 0 && - referenceDatasetTarballUrlValue.length > 0 && + (skipTestsValue === "true" ? 
skipTestsValue === "true" : referenceDatasetTarballUrlValue.length > 0) && dataPreparationContainerValue > 0 && referenceModelValue > 0 && evaluatorContainerValue > 0 @@ -63,4 +64,14 @@ $(document).ready(() => { }); $("#benchmark-register-form input, #benchmark-register-form textarea, #benchmark-register-form select").on("keyup change", checkBenchmarkFormValidity); + + $("input[name='skip_compatibility_tests']").on("change", () => { + if($("#skip-tests").is(":checked")){ + $("#demo-dataset-input-container").hide(); + $("#reference-dataset-tarball-url").val(""); + } + else{ + $("#demo-dataset-input-container").show(); + } + }); }); \ No newline at end of file diff --git a/cli/medperf/web_ui/static/js/cc.js b/cli/medperf/web_ui/static/js/cc.js new file mode 100644 index 000000000..7975e5571 --- /dev/null +++ b/cli/medperf/web_ui/static/js/cc.js @@ -0,0 +1,110 @@ +const fields = [ + "cc-project_id", + "cc-project_number", + "cc-bucket", + "cc-keyring_name", + "cc-key_name", + "cc-key_location", + "cc-wip", + "cc-wip_provider", +]; + +function onCCEditRequestSuccess(response){ + markAllStagesAsComplete(); + if (response.status === "success"){ + showReloadModal({ + title: "CC Configuration Edited Successfully", + seconds: 3 + }); + } + else { + showErrorModal("Failed to Edit CC Configuration", response); + } +} + +function onCCPolicyRequestSuccess(response){ + if (response.status === "success"){ + showReloadModal({ + title: "CC Policy Synced Successfully", + seconds: 3 + }); + } + else { + showErrorModal("Failed to Sync CC Policy", response); + } +} + + +function checkForCCEditChanges() { + const preferences = window.ccPreferences || {}; + const configured = preferences.can_apply; + const defaults = preferences.defaults || {}; + const requireCCChanged = ($("#require-cc").is(":checked") !== configured) + var hasChanges = fields.some(field => { + let currentValue = $(`#${field}`).val(); + let defaultValue = defaults[field] || ""; + return currentValue !== defaultValue; + 
}); + hasChanges = hasChanges || requireCCChanged; + $('#apply-cc-asset-btn').prop('disabled', !hasChanges); +} + +function SyncCCPolicy(syncCCPolicyBtn) { + const entityId = syncCCPolicyBtn.getAttribute("data-entity-id"); + const entityType = syncCCPolicyBtn.getAttribute("data-entity-type"); + const url = `/${entityType}s/sync_cc_policy`; + const formData = new FormData(); + formData.append("entity_id", entityId); + + disableElements(syncCCPolicyBtn); + disableElements(".card button"); + ajaxRequest( + url, + "POST", + formData, + onCCPolicyRequestSuccess, + "Error syncing CC policy:" + ); +} + +async function editCCConfig(editCCConfigBtn) { + const formData = new FormData($("#edit-cc-asset-form")[0]); + const entityId = editCCConfigBtn.getAttribute("data-entity-id"); + const entityType = editCCConfigBtn.getAttribute("data-entity-type"); + formData.append("entity_id", entityId); + const url = `/${entityType}s/edit_cc_config`; + + disableElements("#edit-cc-asset-form input, #edit-cc-asset-form button"); + disableElements(".card button"); + + ajaxRequest( + url, + "POST", + formData, + onCCEditRequestSuccess, + "Error editing CC Configuration:" + ); + showPanel(`Updating CC Configuration...`); + window.runningTaskId = await getTaskId(); + streamEvents(logPanel, stagesList, currentStageElement); +} + + +$(document).ready(() => { + const checkbox = $("#require-cc"); + checkbox.on("change", () => { + $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); + }); + $("#edit-cc-asset-fields").toggle(checkbox.is(":checked")); + + [...fields, "require-cc"] + .forEach(field => $(`#${field}`).on('keyup change', checkForCCEditChanges)); + checkForCCEditChanges(); + + $("#apply-cc-asset-btn").on("click", (e) => { + showConfirmModal(e.currentTarget, editCCConfig, "edit CC configuration?"); + }); + + $("#sync-cc-policy-btn").on("click", (e) => { + showConfirmModal(e.currentTarget, SyncCCPolicy, "sync CC policy?"); + }); +}); diff --git
a/cli/medperf/web_ui/static/js/cc_operator.js b/cli/medperf/web_ui/static/js/cc_operator.js new file mode 100644 index 000000000..5e42877f8 --- /dev/null +++ b/cli/medperf/web_ui/static/js/cc_operator.js @@ -0,0 +1,65 @@ +const fields = [ + "operator-project_id", + "operator-service_account_name", + "operator-bucket", + "operator-vm_zone", + "operator-vm_name", +]; + +function checkForCCEditChanges() { + const preferences = window.ccPreferences || {}; + const configured = preferences.can_apply; + const defaults = preferences.defaults || {}; + const requireCCChanged = ($("#require-cc-operator").is(":checked") !== configured) + var hasChanges = fields.some(field => { + let currentValue = $(`#${field}`).val(); + let defaultValue = defaults[field] || ""; + return currentValue !== defaultValue; + }); + hasChanges = hasChanges || requireCCChanged; + $('#apply-cc-operator-btn').prop('disabled', !hasChanges); +} + +function editCCConfig(editCCConfigBtn) { + const formData = new FormData($("#edit-cc-operator-form")[0]); + + // TODO: properly disable elements + disableElements("#profiles-form select, #profiles-form button"); + disableElements("#edit-config-form input, #edit-config-form button, #edit-config-form select"); + disableElements("#certificate-settings button"); + + ajaxRequest( + "/settings/edit_cc_operator", + "POST", + formData, + (response) => { + if (response.status === "success") { + showReloadModal({ + title: "CC Configuration Edited Successfully", + seconds: 3 + }); + } + else { + showErrorModal("Failed to Edit CC Configuration", response); + } + }, + "Error editing CC Configuration:" + ); +} + + +$(document).ready(() => { + const checkbox = $("#require-cc-operator"); + checkbox.on("change", () => { + $("#edit-cc-operator-fields").toggle(checkbox.is(":checked")); + }); + $("#edit-cc-operator-fields").toggle(checkbox.is(":checked")); + + [...fields, "require-cc-operator"] + .forEach(field => $(`#${field}`).on('keyup change', checkForCCEditChanges)); +
checkForCCEditChanges(); + + $("#apply-cc-operator-btn").on("click", (e) => { + showConfirmModal(e.currentTarget, editCCConfig, "edit CC configuration?"); + }); + +}); diff --git a/cli/medperf/web_ui/static/js/containers/container_register.js b/cli/medperf/web_ui/static/js/containers/container_register.js index f9914d5e2..eb7b51b3e 100644 --- a/cli/medperf/web_ui/static/js/containers/container_register.js +++ b/cli/medperf/web_ui/static/js/containers/container_register.js @@ -2,13 +2,13 @@ function onContainerRegisterSuccess(response){ markAllStagesAsComplete(); if(response.status === "success"){ showReloadModal({ - title: "Model Registered Successfully", + title: "Container Registered Successfully", seconds: 3, url: "/containers/ui/display/"+response.container_id }); } else{ - showErrorModal("Failed to Register Model", response); + showErrorModal("Failed to Register Container", response); } } diff --git a/cli/medperf/web_ui/templates/benchmark/register_benchmark.html b/cli/medperf/web_ui/templates/benchmark/register_benchmark.html index 2c08dd071..7f4516260 100644 --- a/cli/medperf/web_ui/templates/benchmark/register_benchmark.html +++ b/cli/medperf/web_ui/templates/benchmark/register_benchmark.html @@ -14,6 +14,16 @@

Register a New Benchmark

+
+
+ + +
+
+ + +
+
Register a New Benchmark >
-
+
Register a New Container
- +
- +
diff --git a/cli/medperf/web_ui/templates/dataset/dataset_detail.html b/cli/medperf/web_ui/templates/dataset/dataset_detail.html index 78936505e..2533f781a 100644 --- a/cli/medperf/web_ui/templates/dataset/dataset_detail.html +++ b/cli/medperf/web_ui/templates/dataset/dataset_detail.html @@ -5,6 +5,7 @@ {% import 'macros/container_macros.html' as container_macros %} {% import 'macros/model_macros.html' as model_macros %} {% import 'macros/benchmark_macros.html' as benchmark_macros %} +{% import 'macros/cc_asset_macro.html' as cc_asset_macro %} {% block title %}Dataset Details{% endblock title %} @@ -16,6 +17,12 @@

{{ dataset.name }}

+{% if dataset_hash_mismatch %} +

+ Data hash mismatch + +

+{% endif %}
@@ -199,6 +206,27 @@
Details
Step 3: Make a request to the benchmark owner to associate your dataset with the benchmark
+
+
+
Confidential Computing Preferences
+
+
+ {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, dataset.id, "dataset", task_running) }} +
+ {% if cc_configured %} +
+

Confidential Computing is configured, you can sync the policy.

+ +
+ {% endif %} +
{% endif %}
@@ -342,6 +370,21 @@

> Access Pending {% endif %}

+ {% endif %} + {% if model._requires_cc %} +
+  Confidential Computing Model + {% if model.cc_run_status["can_run"] %} +  Ready + {% else %} +  Not Ready + {% endif %} +
{% endif %} {{ model_macros.model_link(model) }} @@ -356,7 +399,7 @@

data-model-name="{{ model.name }}" id="run-{{ assoc }}-{{ model.id }}" {% if results_exist %} rerun="true" {% endif %} - {% if task_running or (model._encrypted and not model.access_status["has_access"]) %} disabled {% endif %} + {% if task_running or (model._encrypted and not model.access_status["has_access"]) or (model._requires_cc and not model.cc_run_status["can_run"]) %} disabled {% endif %} > ▶️ {% if results_exist %} Rerun {% else %} Run {% endif %} @@ -418,9 +461,16 @@

{% endblock detail_panel %} {% block extra_js %} + + {% if task_running and task_formData.get("dataset_id", "") == dataset.id|string %} {% if request.app.state.task.name == "dataset_preparation" %} @@ -490,4 +540,17 @@

{% endif %} {% endif %} + +{% if task_running and request.app.state.task.name == "data_update_cc_config" and task_formData.get("dataset_id", "") == entity.id|string %} + +{% endif %} + {% endblock extra_js %} diff --git a/cli/medperf/web_ui/templates/dataset/register_dataset.html b/cli/medperf/web_ui/templates/dataset/register_dataset.html index 2c07f045f..89758e8fb 100644 --- a/cli/medperf/web_ui/templates/dataset/register_dataset.html +++ b/cli/medperf/web_ui/templates/dataset/register_dataset.html @@ -151,6 +151,26 @@

Register a New Dataset

> +
+ + + +
+
+ + +{% endmacro %} \ No newline at end of file diff --git a/cli/medperf/web_ui/templates/macros/cc_operator_macro.html b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html new file mode 100644 index 000000000..845721233 --- /dev/null +++ b/cli/medperf/web_ui/templates/macros/cc_operator_macro.html @@ -0,0 +1,44 @@ +{# ./cli/medperf/web_ui/templates/macros/cc_operator_macro.html #} + +{% macro cc_operator(defaults, configured, task_running=false) %} +{% set fields = [ +("project_id", "GCP Project ID"), +("service_account_name", "GCP Service Account Name"), +("bucket", "GCP Bucket Name"), +("vm_zone", "VM Zone"), +("vm_name", "VM Name"), +] %} +
+
+
+
+

+ Confidential Computing Operator Settings +

+
+
+
+ + +
+
+
+
+ {% for field_name, field_label in fields %} +
+ +
+ +
+
+ {% endfor %} +
+
+ +
+
+
+{% endmacro %} \ No newline at end of file diff --git a/cli/medperf/web_ui/templates/model/model_detail.html b/cli/medperf/web_ui/templates/model/model_detail.html index 83046038a..b6ef924be 100644 --- a/cli/medperf/web_ui/templates/model/model_detail.html +++ b/cli/medperf/web_ui/templates/model/model_detail.html @@ -4,6 +4,7 @@ {% import 'macros/association_card_macros.html' as association_card_macros %} {% import 'macros/model_detail_util_macro.html' as model_detail_util_macro %} +{% import 'macros/cc_asset_macro.html' as cc_asset_macro %} {% block title %}Model Details{% endblock %} @@ -143,6 +144,27 @@
Details
Make a request to the benchmark owner to associate your model with the benchmark +
+
+
Confidential Computing Preferences
+
+
+ {{cc_asset_macro.gcp_asset(cc_config_defaults, cc_configured, entity.id, "model", task_running) }} +
+ {% if cc_configured %} +
+

Confidential Computing is configured, you can sync the policy.

+ +
+ {% endif %} +
{% endif %} @@ -164,8 +186,15 @@

Associated Benchmarks

{% endblock detail_panel %} {% block extra_js %} + + {% if task_running and request.app.state.task.name == "model_association" and task_formData.get("model_id", "") == entity.id|string %} {% endif %} + +{% if task_running and request.app.state.task.name == "model_update_cc_config" and task_formData.get("model_id", "") == entity.id|string %} + +{% endif %} {% endblock %} \ No newline at end of file diff --git a/cli/medperf/web_ui/templates/settings.html b/cli/medperf/web_ui/templates/settings.html index dde3268a9..f630dc3f7 100644 --- a/cli/medperf/web_ui/templates/settings.html +++ b/cli/medperf/web_ui/templates/settings.html @@ -3,6 +3,7 @@ {% extends "base.html" %} {% import "constants/tooltips.html" as tooltips %} +{% import 'macros/cc_operator_macro.html' as cc_operator_macro %} {% block title %}Settings{% endblock %} @@ -140,6 +141,12 @@
Certificate Exists
Status: invalid

+ {% else %} +

+ Status: + valid +

+ {% endif %} - {% else %} -

- Status: - valid -

- {% endif %} {% endif %} {% endif %} @@ -165,11 +166,22 @@
Certificate Exists
{% include "partials/prompt_container.html" %} - +
+{% if logged_in %} + {{ cc_operator_macro.cc_operator(cc_config_defaults, cc_configured, task_running=task_running) }} +{% endif %} +
{% endblock content %} {% block extra_js %} + +