diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index c8ea4bfba..cdbc828e7 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -6,10 +6,21 @@ from medperf.entities.benchmark import Benchmark from medperf.commands.list import EntityList from medperf.commands.view import EntityView +from medperf.commands.edit import EntityEdit from medperf.commands.benchmark.submit import SubmitBenchmark from medperf.commands.benchmark.associate import AssociateBenchmark from medperf.commands.result.create import BenchmarkExecution +NAME_HELP = "Name of the benchmark" +DESC_HELP = "Description of the benchmark" +DOCS_HELP = "URL to documentation" +DEMO_URL_HELP = """Identifier to download the demonstration dataset tarball file.\n + See `medperf mlcube submit --help` for more information""" +DEMO_HASH_HELP = "SHA1 of demonstration dataset tarball file" +DATA_PREP_HELP = "Data Preparation MLCube UID" +MODEL_HELP = "Reference Model MLCube UID" +EVAL_HELP = "Evaluator MLCube UID" + app = typer.Typer() @@ -31,31 +42,61 @@ def list( @app.command("submit") @clean_except def submit( - name: str = typer.Option(..., "--name", "-n", help="Name of the benchmark"), - description: str = typer.Option( - ..., "--description", "-d", help="Description of the benchmark" + name: str = typer.Option(..., "--name", "-n", help=NAME_HELP), + description: str = typer.Option(..., "--description", "-d", help=DESC_HELP), + docs_url: str = typer.Option("", "--docs-url", "-u", help=DOCS_HELP), + demo_url: str = typer.Option("", "--demo-url", help=DEMO_URL_HELP), + demo_hash: str = typer.Option("", "--demo-hash", help=DEMO_HASH_HELP), + data_preparation_mlcube: int = typer.Option( + ..., "--data-preparation-mlcube", "-p", help=DATA_PREP_HELP ), - docs_url: str = typer.Option("", "--docs-url", "-u", help="URL to documentation"), - demo_url: str = typer.Option( - "", - "--demo-url", - help="""Identifier to download the demonstration dataset tarball file.\n - See `medperf mlcube submit --help` for more information""", + reference_model_mlcube: int = typer.Option( + ..., "--reference-model-mlcube", "-m", help=MODEL_HELP ), - demo_hash: str = typer.Option( - "", "--demo-hash", help="SHA1 of demonstration dataset tarball file" + evaluator_mlcube: int = typer.Option( + ..., "--evaluator-mlcube", "-e", help=EVAL_HELP ), +): + """Submits a new benchmark to the platform""" + benchmark_info = { + "name": name, + "description": description, + "docs_url": docs_url, + "demo_dataset_tarball_url": demo_url, + "demo_dataset_tarball_hash": demo_hash, + "data_preparation_mlcube": data_preparation_mlcube, + "reference_model_mlcube": reference_model_mlcube, + "data_evaluator_mlcube": evaluator_mlcube, + } + SubmitBenchmark.run(benchmark_info) + config.ui.print("✅ Done!") + + +@app.command("edit") +@clean_except +def edit( + entity_id: int = typer.Argument(..., help="Benchmark ID"), + name: str = typer.Option(None, "--name", "-n", help=NAME_HELP), + description: str = typer.Option(None, "--description", "-d", help=DESC_HELP), + docs_url: str = typer.Option(None, "--docs-url", "-u", help=DOCS_HELP), + demo_url: str = typer.Option(None, "--demo-url", help=DEMO_URL_HELP), + demo_hash: str = typer.Option(None, "--demo-hash", help=DEMO_HASH_HELP), data_preparation_mlcube: int = typer.Option( - ..., "--data-preparation-mlcube", "-p", help="Data Preparation MLCube UID" + None, "--data-preparation-mlcube", "-p", help=DATA_PREP_HELP ), reference_model_mlcube: int = typer.Option( - ..., "--reference-model-mlcube", "-m", help="Reference Model MLCube UID" + None, "--reference-model-mlcube", "-m", help=MODEL_HELP ), evaluator_mlcube: int = typer.Option( - ..., "--evaluator-mlcube", "-e", help="Evaluator MLCube UID" + None, "--evaluator-mlcube", "-e", help=EVAL_HELP + ), + is_valid: bool = typer.Option( + None, + "--valid/--invalid", + help="Flags a dataset valid/invalid. Invalid datasets can't be used for experiments", ), ): - """Submits a new benchmark to the platform""" + """Edits a benchmark""" benchmark_info = { "name": name, "description": description, @@ -65,8 +106,9 @@ def submit( "data_preparation_mlcube": data_preparation_mlcube, "reference_model_mlcube": reference_model_mlcube, "data_evaluator_mlcube": evaluator_mlcube, + "is_valid": is_valid, } - SubmitBenchmark.run(benchmark_info) + EntityEdit.run(Benchmark, entity_id, benchmark_info) config.ui.print("✅ Done!") @@ -84,11 +126,12 @@ def associate( ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), no_cache: bool = typer.Option( - False, "--no-cache", help="Execute the test even if results already exist", + False, + "--no-cache", + help="Execute the test even if results already exist", ), ): - """Associates a benchmark with a given mlcube or dataset. Only one option at a time. - """ + """Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" AssociateBenchmark.run( benchmark_uid, model_uid, dataset_uid, approved=approval, no_cache=no_cache ) @@ -118,11 +161,12 @@ def run( help="Ignore failing model cubes, allowing for possibly submitting partial results", ), no_cache: bool = typer.Option( - False, "--no-cache", help="Execute even if results already exist", + False, + "--no-cache", + help="Execute even if results already exist", ), ): - """Runs the benchmark execution step for a given benchmark, prepared dataset and model - """ + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" BenchmarkExecution.run( benchmark_uid, data_uid, @@ -163,6 +207,5 @@ def view( help="Output file to store contents. If not provided, the output will be displayed", ), ): - """Displays the information of one or more benchmarks - """ + """Displays the information of one or more benchmarks""" EntityView.run(entity_id, Benchmark, format, local, mine, output) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 07c97153c..a7430add2 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -6,10 +6,16 @@ from medperf.entities.dataset import Dataset from medperf.commands.list import EntityList from medperf.commands.view import EntityView +from medperf.commands.edit import EntityEdit from medperf.commands.dataset.create import DataPreparation from medperf.commands.dataset.submit import DatasetRegistration from medperf.commands.dataset.associate import AssociateDataset +NAME_HELP = "Name of the dataset" +DESC_HELP = "Description of the dataset" +LOC_HELP = "Location or Institution the data belongs to" +LOC_OPTION = typer.Option(..., "--location", help=LOC_HELP) + app = typer.Typer() @@ -43,16 +49,11 @@ def create( labels_path: str = typer.Option( ..., "--labels_path", "-l", help="Labels file location" ), - name: str = typer.Option(..., "--name", help="Name of the dataset"), - description: str = typer.Option( - ..., "--description", help="Description of the dataset" - ), - location: str = typer.Option( - ..., "--location", help="Location or Institution the data belongs to" - ), + name: str = typer.Option(..., "--name", help=NAME_HELP), + description: str = typer.Option(..., "--description", help=DESC_HELP), + location: str = typer.Option(..., "--location", help=LOC_HELP), ): - """Runs the Data preparation step for a specified benchmark and raw dataset - """ + """Runs the Data preparation step for a specified benchmark and raw dataset""" ui = config.ui data_uid = DataPreparation.run( benchmark_uid, @@ -77,8 +78,7 @@ def register( ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): - """Submits an unregistered Dataset instance to the backend - """ + """Submits an unregistered Dataset instance to the backend""" ui = config.ui uid = DatasetRegistration.run(data_uid, approved=approval) ui.print("✅ Done!") @@ -87,6 +87,30 @@ def register( ) +@app.command("edit") +@clean_except +def edit( + entity_id: int = typer.Argument(..., help="Dataset ID"), + name: str = typer.Option(None, "--name", help=NAME_HELP), + description: str = typer.Option(None, "--description", help=DESC_HELP), + location: str = typer.Option(None, "--location", help=LOC_HELP), + is_valid: bool = typer.Option( + None, + "--valid/--invalid", + help="Flags a dataset valid/invalid. Invalid datasets can't be used for experiments", + ), +): + """Edits a Dataset""" + dset_info = { + "name": name, + "description": description, + "location": location, + "is_valid": is_valid, + } + EntityEdit.run(Dataset, entity_id, dset_info) + config.ui.print("✅ Done!") + + @app.command("associate") @clean_except def associate( @@ -98,7 +122,9 @@ def associate( ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), no_cache: bool = typer.Option( - False, "--no-cache", help="Execute the test even if results already exist", + False, + "--no-cache", + help="Execute the test even if results already exist", ), ): """Associate a registered dataset with a specific benchmark. @@ -137,6 +163,5 @@ def view( help="Output file to store contents. If not provided, the output will be displayed", ), ): - """Displays the information of one or more datasets - """ + """Displays the information of one or more datasets""" EntityView.run(entity_id, Dataset, format, local, mine, output) diff --git a/cli/medperf/commands/edit.py b/cli/medperf/commands/edit.py new file mode 100644 index 000000000..1d5f1d00d --- /dev/null +++ b/cli/medperf/commands/edit.py @@ -0,0 +1,40 @@ +from medperf.entities.interface import Updatable +from medperf.exceptions import InvalidEntityError + + +class EntityEdit: + @staticmethod + def run(entity_class, id: str, fields: dict): + """Edits and updates an entity both locally and on the server if possible + + Args: + entity (Editable): Entity to modify + fields (dict): Dicitonary of fields and values to modify + """ + editor = EntityEdit(entity_class, id, fields) + editor.prepare() + editor.validate() + editor.edit() + + def __init__(self, entity_class, id, fields): + self.entity_class = entity_class + self.id = id + self.fields = fields + + def prepare(self): + self.entity = self.entity_class.get(self.id) + # Filter out empty fields + self.fields = {k: v for k, v in self.fields.items() if v is not None} + + def validate(self): + if not isinstance(self.entity, Updatable): + raise InvalidEntityError("The passed entity can't be edited") + + def edit(self): + entity = self.entity + entity.edit(**self.fields) + + if isinstance(entity, Updatable) and entity.is_registered: + entity.update() + + entity.write() diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 51c218e56..1244e2ae7 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -6,12 +6,25 @@ from medperf.entities.cube import Cube from medperf.commands.list import EntityList from medperf.commands.view import EntityView +from medperf.commands.edit import EntityEdit from medperf.commands.mlcube.create import CreateCube from medperf.commands.mlcube.submit import SubmitCube from medperf.commands.mlcube.associate import AssociateCube app = typer.Typer() +NAME_HELP = "Name of the mlcube" +MLCUBE_HELP = "Identifier to download the mlcube file. See the description above" +MLCUBE_HASH_HELP = "SHA1 of mlcube file" +PARAMS_HELP = "Identifier to download the parameters file. See the description above" +PARAMS_HASH_HELP = "SHA1 of parameters file" +ADD_HELP = ( + "Identifier to download the additional files tarball. See the description above" +) +ADD_HASH_HELP = "SHA1 of additional file" +IMG_HELP = "Identifier to download the image file. See the description above" +IMG_HASH_HELP = "SHA1 of image file" + @app.command("ls") @clean_except @@ -52,39 +65,35 @@ def create( @app.command("submit") @clean_except def submit( - name: str = typer.Option(..., "--name", "-n", help="Name of the mlcube"), + name: str = typer.Option(..., "--name", "-n", help=NAME_HELP), mlcube_file: str = typer.Option( ..., "--mlcube-file", "-m", - help="Identifier to download the mlcube file. See the description above", + help=MLCUBE_HELP, ), - mlcube_hash: str = typer.Option("", "--mlcube-hash", help="SHA1 of mlcube file"), + mlcube_hash: str = typer.Option("", "--mlcube-hash", help=MLCUBE_HASH_HELP), parameters_file: str = typer.Option( "", "--parameters-file", "-p", - help="Identifier to download the parameters file. See the description above", - ), - parameters_hash: str = typer.Option( - "", "--parameters-hash", help="SHA1 of parameters file" + help=PARAMS_HELP, ), + parameters_hash: str = typer.Option("", "--parameters-hash", help=PARAMS_HASH_HELP), additional_file: str = typer.Option( "", "--additional-file", "-a", - help="Identifier to download the additional files tarball. See the description above", - ), - additional_hash: str = typer.Option( - "", "--additional-hash", help="SHA1 of additional file" + help=ADD_HELP, ), + additional_hash: str = typer.Option("", "--additional-hash", help=ADD_HASH_HELP), image_file: str = typer.Option( "", "--image-file", "-i", - help="Identifier to download the image file. See the description above", + help=IMG_HELP, ), - image_hash: str = typer.Option("", "--image-hash", help="SHA1 of image file"), + image_hash: str = typer.Option("", "--image-hash", help=IMG_HASH_HELP), ): """Submits a new cube to the platform.\n The following assets:\n @@ -115,6 +124,64 @@ def submit( config.ui.print("✅ Done!") +@app.command("edit") +@clean_except +def edit( + entity_id: int = typer.Argument(..., help="Dataset ID"), + name: str = typer.Option(None, "--name", "-n", help=NAME_HELP), + mlcube_file: str = typer.Option( + None, + "--mlcube-file", + "-m", + help=MLCUBE_HELP, + ), + mlcube_hash: str = typer.Option(None, "--mlcube-hash", help=MLCUBE_HASH_HELP), + parameters_file: str = typer.Option( + None, + "--parameters-file", + "-p", + help=PARAMS_HELP, + ), + parameters_hash: str = typer.Option( + None, "--parameters-hash", help=PARAMS_HASH_HELP + ), + additional_file: str = typer.Option( + None, + "--additional-file", + "-a", + help=ADD_HELP, + ), + additional_hash: str = typer.Option(None, "--additional-hash", help=ADD_HASH_HELP), + image_file: str = typer.Option( + None, + "--image-file", + "-i", + help=IMG_HELP, + ), + image_hash: str = typer.Option(None, "--image-hash", help=IMG_HASH_HELP), + is_valid: bool = typer.Option( + None, + "--valid/--invalid", + help="Flags an MLCube valid/invalid. Invalid MLCubes can't be used for experiments", + ), +): + """Edits an MLCube""" + mlcube_info = { + "name": name, + "git_mlcube_url": mlcube_file, + "git_mlcube_hash": mlcube_hash, + "git_parameters_url": parameters_file, + "parameters_hash": parameters_hash, + "image_tarball_url": image_file, + "image_tarball_hash": image_hash, + "additional_files_tarball_url": additional_file, + "additional_files_tarball_hash": additional_hash, + "is_valid": is_valid, + } + EntityEdit.run(Cube, entity_id, mlcube_info) + config.ui.print("✅ Done!") + + @app.command("associate") @clean_except def associate( @@ -122,7 +189,9 @@ def associate( model_uid: int = typer.Option(..., "--model_uid", "-m", help="Model UID"), approval: bool = typer.Option(False, "-y", help="Skip approval step"), no_cache: bool = typer.Option( - False, "--no-cache", help="Execute the test even if results already exist", + False, + "--no-cache", + help="Execute the test even if results already exist", ), ): """Associates an MLCube to a benchmark""" @@ -155,6 +224,5 @@ def view( help="Output file to store contents. If not provided, the output will be displayed", ), ): - """Displays the information of one or more mlcubes - """ + """Displays the information of one or more mlcubes""" EntityView.run(entity_id, Cube, format, local, mine, output) diff --git a/cli/medperf/comms/entity_resources/resources.py b/cli/medperf/comms/entity_resources/resources.py index e0493797f..b216d6eeb 100644 --- a/cli/medperf/comms/entity_resources/resources.py +++ b/cli/medperf/comms/entity_resources/resources.py @@ -26,7 +26,9 @@ from .utils import download_resource -def get_cube(url: str, cube_path: str, expected_hash: str = None) -> str: +def get_cube( + url: str, cube_path: str, expected_hash: str = None, force: bool = False +) -> str: """Downloads and writes an mlcube.yaml file. If the hash is provided, the file's integrity will be checked upon download. @@ -34,39 +36,45 @@ def get_cube(url: str, cube_path: str, expected_hash: str = None) -> str: url (str): URL where the mlcube.yaml file can be downloaded. cube_path (str): Cube location. expected_hash (str, optional): expected sha1 hash of the downloaded file + force (bool, optional): Wether to force redownload or not Returns: output_path (str): location where the mlcube.yaml file is stored locally. hash_value (str): The hash of the downloaded file """ output_path = os.path.join(cube_path, config.cube_filename) - if os.path.exists(output_path): + if not force and os.path.exists(output_path): return output_path, expected_hash hash_value = download_resource(url, output_path, expected_hash) return output_path, hash_value -def get_cube_params(url: str, cube_path: str, expected_hash: str = None) -> str: +def get_cube_params( + url: str, cube_path: str, expected_hash: str = None, force: bool = False +) -> str: """Downloads and writes a cube parameters file. If the hash is provided, the file's integrity will be checked upon download. Args: url (str): URL where the parameters.yaml file can be downloaded. cube_path (str): Cube location. - expected_hash (str, optional): expected sha1 hash of the downloaded file + expected_hash (str, Optional): expected sha1 hash of the downloaded file + force (bool, Optional): Wether to force redownload or not Returns: output_path (str): location where the parameters file is stored locally. hash_value (str): The hash of the downloaded file """ output_path = os.path.join(cube_path, config.workspace_path, config.params_filename) - if os.path.exists(output_path): + if not force and os.path.exists(output_path): return output_path, expected_hash hash_value = download_resource(url, output_path, expected_hash) return output_path, hash_value -def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: +def get_cube_image( + url: str, cube_path: str, hash_value: str = None, force: bool = False +) -> str: """Retrieves and stores the image file from the server. Stores images on a shared location, and retrieves a cached image by hash if found locally. Creates a symbolic link to the cube storage. @@ -75,6 +83,7 @@ def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: url (str): URL where the image file can be downloaded. cube_path (str): Path to cube. hash_value (str, Optional): File hash to store under shared storage. Defaults to None. + force (bool, Optional): Wether to force redownload or not Returns: image_cube_file: Location where the image file is stored locally. @@ -98,7 +107,7 @@ def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: shutil.move(tmp_output_path, img_storage) else: img_storage = os.path.join(imgs_storage, hash_value) - if not os.path.exists(img_storage): + if force or not os.path.exists(img_storage): # If image doesn't exist locally, download it normally download_resource(url, img_storage, hash_value) @@ -108,7 +117,7 @@ def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str: def get_cube_additional( - url: str, cube_path: str, expected_tarball_hash: str = None, + url: str, cube_path: str, expected_tarball_hash: str = None, force: bool = False ) -> str: """Retrieves additional files of an MLCube. The additional files will be in a compressed tarball file. The function will additionally @@ -117,7 +126,8 @@ def get_cube_additional( Args: url (str): URL where the additional_files.tar.gz file can be downloaded. cube_path (str): Cube location. - expected_tarball_hash (str, optional): expected sha1 hash of tarball file + expected_tarball_hash (str, Optional): expected sha1 hash of tarball file + force (bool, Optional): Wether to force redownload or not Returns: tarball_hash (str): The hash of the downloaded tarball file @@ -125,7 +135,10 @@ def get_cube_additional( additional_files_folder = os.path.join(cube_path, config.additional_path) if os.path.exists(additional_files_folder): - return expected_tarball_hash + if force: + shutil.rmtree(additional_files_folder) + else: + return expected_tarball_hash # make sure files are uncompressed while in tmp storage, to avoid any clutter # objects if uncompression fails for some reason. diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index f934435d5..17eab4616 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -133,6 +133,14 @@ def upload_benchmark(self, benchmark_dict: dict) -> int: int: UID of newly created benchmark """ + @abstractmethod + def update_benchmark(self, id: int, benchmark_dict: dict): + """Updates the benchmark with the given id and the new dictionary + + Args: + benchmark_dict (dict): updated benchmark data + """ + @abstractmethod def upload_mlcube(self, mlcube_body: dict) -> int: """Uploads an MLCube instance to the platform @@ -144,6 +152,14 @@ def upload_mlcube(self, mlcube_body: dict) -> int: int: id of the created mlcube instance on the platform """ + @abstractmethod + def update_mlcube(self, id: int, mlcube_dict: dict): + """Updates the mlcube with the given id and the new dictionary + + Args: + mlcube_dict (dict): updated mlcube data + """ + @abstractmethod def get_datasets(self) -> List[dict]: """Retrieves all datasets in the platform @@ -182,6 +198,14 @@ def upload_dataset(self, reg_dict: dict) -> int: int: id of the created dataset registration. """ + @abstractmethod + def update_dataset(self, id: int, dataset_dict: dict): + """Updates the dataset with the given id and the new dictionary + + Args: + dataset_dict (dict): updated dataset data + """ + @abstractmethod def get_results(self) -> List[dict]: """Retrieves all results diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 3d4c85d71..0b6e94755 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -295,6 +295,17 @@ def upload_benchmark(self, benchmark_dict: dict) -> int: raise CommunicationRetrievalError(f"Could not upload benchmark: {details}") return res.json() + def update_benchmark(self, id: int, benchmark_dict: dict): + """Updates the benchmark with the given id and the new dictionary + + Args: + benchmark_dict (dict): updated benchmark data + """ + res = self.__auth_put(f"{self.server_url}/benchmarks/{id}/", json=benchmark_dict) + if res.status_code != 200: + log_response_error(res) + raise CommunicationRequestError(f"Could not update benchmark: {res.text}") + def upload_mlcube(self, mlcube_body: dict) -> int: """Uploads an MLCube instance to the platform @@ -311,6 +322,17 @@ def upload_mlcube(self, mlcube_body: dict) -> int: raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}") return res.json() + def update_mlcube(self, id: int, mlcube_dict: dict): + """Updates the mlcube with the given id and the new dictionary + + Args: + mlcube_dict (dict): updated mlcube data + """ + res = self.__auth_put(f"{self.server_url}/mlcubes/{id}/", json=mlcube_dict) + if res.status_code != 200: + log_response_error(res) + raise CommunicationRequestError(f"Could not update mlcube: {res.text}") + def get_datasets(self) -> List[dict]: """Retrieves all datasets in the platform @@ -363,6 +385,17 @@ def upload_dataset(self, reg_dict: dict) -> int: raise CommunicationRequestError(f"Could not upload the dataset: {details}") return res.json() + def update_dataset(self, id: int, dataset_dict: dict): + """Updates the dataset with the given id and the new dictionary + + Args: + dataset_dict (dict): updated dataset data + """ + res = self.__auth_put(f"{self.server_url}/datasets/{id}/", json=dataset_dict) + if res.status_code != 200: + log_response_error(res) + raise CommunicationRequestError(f"Could not update dataset: {res.text}") + def get_results(self) -> List[dict]: """Retrieves all results diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index b2ea1496a..9a956bc83 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -2,17 +2,18 @@ from medperf.exceptions import MedperfException import yaml import logging +from deepdiff import DeepDiff from typing import List, Optional, Union from pydantic import HttpUrl, Field, validator import medperf.config as config -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity, Updatable from medperf.utils import storage_path from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, Updatable, MedperfSchema, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -231,6 +232,66 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: """ return config.comms.get_benchmark_models(benchmark_uid) + def edit(self, **kwargs): + """Edits a benchmark with the given property-value pairs""" + data = self.todict() + data.update(kwargs) + new_bmk = Benchmark(**data) + + self.__validate_edit(new_bmk) + + self.__dict__.update(**new_bmk.__dict__) + + def __validate_edit(self, new_bmk: "Benchmark"): + """Validates that an update is valid given the changes made + + Args: + old_bmk (Benchmark): The old version of the Benchmark + new_bmk (Benchmark): The new version of the same Benchmark + + Raises: + InvalidArgumentError: The changed fields are not mutable + """ + old_bmk = self + # Field that shouldn't ber modified directly by the user + inmutable_fields = { + "id", + } + + # Fields that can no longer be modified while in production + production_inmutable_fields = { + "name", + "description" "demo_dataset_tarball_hash", + "demo_dataset_generated_uid", + "data_preparation_mlcube", + "reference_model_mlcube", + "data_evaluator_mlcube", + } + + if old_bmk.state == "OPERATION": + inmutable_fields = inmutable_fields.union(production_inmutable_fields) + + bmk_diffs = DeepDiff(new_bmk.todict(), old_bmk.todict()) + updated_fields = set(bmk_diffs.affected_root_keys) + + updated_inmutable_fields = updated_fields.intersection(inmutable_fields) + + if len(updated_inmutable_fields): + fields_msg = ", ".join(updated_inmutable_fields) + msg = ( + "The following fields can't be directly edited: " + + fields_msg + + ". For these changes, a new Benchmark is required" + ) + raise InvalidArgumentError(msg) + + def update(self): + """Updates the benchmark on the server""" + if not self.is_registered: + raise MedperfException("Can't update an unregistered benchmark") + body = self.todict() + config.comms.update_benchmark(self.id, body) + def todict(self) -> dict: """Dictionary representation of the benchmark instance diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index b71291e78..d9d3df35c 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -2,12 +2,13 @@ import yaml import pexpect import logging +from deepdiff import DeepDiff from typing import List, Dict, Optional, Union from pydantic import Field from pathlib import Path from medperf.utils import combine_proc_sp_text, list_files, storage_path -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity, Updatable from medperf.entities.schemas import MedperfSchema, DeployableSchema from medperf.exceptions import ( InvalidArgumentError, @@ -20,7 +21,7 @@ from medperf.comms.entity_resources import resources -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, Updatable, MedperfSchema, DeployableSchema): """ Class representing an MLCube Container @@ -177,35 +178,37 @@ def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": cube = cls(**local_meta) return cube - def download_mlcube(self): + def download_mlcube(self, force=False): url = self.git_mlcube_url - path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) + path, file_hash = resources.get_cube( + url, self.path, self.mlcube_hash, force=force + ) self.cube_path = path self.mlcube_hash = file_hash - def download_parameters(self): + def download_parameters(self, force=False): url = self.git_parameters_url if url: path, file_hash = resources.get_cube_params( - url, self.path, self.parameters_hash + url, self.path, self.parameters_hash, force=force ) self.params_path = path self.parameters_hash = file_hash - def download_additional(self): + def download_additional(self, force=False): url = self.additional_files_tarball_url if url: file_hash = resources.get_cube_additional( - url, self.path, self.additional_files_tarball_hash + url, self.path, self.additional_files_tarball_hash, force=force ) self.additional_files_tarball_hash = file_hash - def download_image(self): + def download_image(self, force=False): url = self.image_tarball_url hash = self.image_tarball_hash if url: - _, local_hash = resources.get_cube_image(url, self.path, hash) + _, local_hash = resources.get_cube_image(url, self.path, hash, force=force) self.image_tarball_hash = local_hash else: # Retrieve image from image registry @@ -218,7 +221,6 @@ def download_image(self): def download(self): """Downloads the required elements for an mlcube to run locally.""" - try: self.download_mlcube() except InvalidEntityError as e: @@ -307,6 +309,84 @@ def get_default_output(self, task: str, out_key: str, param_key: str = None) -> return out_path + def edit(self, **kwargs): + """Edits a cube with the given property-value pairs""" + data = self.todict() + + # Include the updated fields + data.update(kwargs) + new_cube = Cube(**data) + + # If any resource is being updated, download and get the new hash + # Hash difference checking is done between the old and new cube + # According to update policies + if "git_mlcube_url" in kwargs: + new_cube.mlcube_hash = kwargs.get("mlcube_hash", None) + new_cube.download_mlcube(force=True) + if "git_parameters_url" in kwargs: + new_cube.parameters_hash = kwargs.get("parameters_hash", None) + new_cube.download_parameters(force=True) + if "image_tarball_url" in kwargs: + new_cube.image_tarball_hash = kwargs.get("image_tarball_hash", None) + new_cube.download_image(force=True) + if "additional_files_tarball_url" in kwargs: + new_cube.additional_files_tarball_hash = kwargs.get( + "additional_files_tarball_hash", None + ) + new_cube.download_additional(force=True) + + self.__validate_edit(new_cube) + + self.__dict__.update(**new_cube.__dict__) + + def __validate_edit(self, new_cube: "Cube"): + """Ensure an edit is valid given the changes made + + Args: + new_cube (Cube): The new version of the same MLCube + + Raises: + InvalidEntityError: The changes created an invalid entity configuration + InvalidArugmentError: The changed fields are not mutable + """ + old_cube = self + # Fields that shouldn't be modified directly by the user + inmutable_fields = { + "id", + } + + # Fields that can no longer be modified while in production + production_inmutable_fields = { + "mlcube_hash", + "parameters_hash", + "image_tarball_hash", + "additional_files_tarball_hash", + } + + if old_cube.state == "OPERATION": + inmutable_fields = inmutable_fields.union(production_inmutable_fields) + + cube_diffs = DeepDiff(new_cube.todict(), old_cube.todict()) + updated_fields = set(cube_diffs.affected_root_keys) + + updated_inmutable_fields = updated_fields.intersection(inmutable_fields) + + if len(updated_inmutable_fields): + fields_msg = ", ".join(updated_inmutable_fields) + msg = ( + "The following fields can't be directly edited: " + + fields_msg + + ". For these changes, a new MLCube is required" + ) + raise InvalidArgumentError(msg) + + def update(self): + """Updates the benchmark on the server""" + if not self.is_registered: + raise MedperfException("Can't update an unregistered cube") + body = self.todict() + config.comms.update_mlcube(self.id, body) + def todict(self) -> Dict: return self.extended_dict() diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 9674237e1..aa101e57e 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,12 +1,13 @@ import os import yaml import logging +from deepdiff import DeepDiff from pydantic import Field, validator from typing import List, Optional, Union from medperf.utils import storage_path from medperf.enums import Status -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity, Updatable from medperf.entities.schemas import MedperfSchema, DeployableSchema from medperf.exceptions import ( InvalidArgumentError, @@ -16,7 +17,7 @@ import medperf.config as config -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, Updatable, MedperfSchema, DeployableSchema): """ Class representing a Dataset @@ -219,6 +220,62 @@ def upload(self): updated_dataset_dict["separate_labels"] = dataset_dict["separate_labels"] return updated_dataset_dict + def edit(self, **kwargs): + """Edits a dataset with the given property-value pairs""" + data = self.todict() + data.update(kwargs) + new_dset = Dataset(**data) + + self.__validate_edit(new_dset) + + self.__dict__.update(**new_dset.__dict__) + + def __validate_edit(self, new_dset: "Dataset"): + """Determines if an update is valid given the changes made + + Args: + new_dset (Dataset): The updated version of the same dataset + + Raises: + InvalidArugmentError: The changed fields are not mutable + """ + old_dset = self + # Fields that shouldn't be modified directly by the user + inmutable_fields = { + "id", + "input_data_hash", + "generated_uid", + "separate_labels", + "generated_metadata", + "data_preparation_mlcube", + } + + # Fields that can no longer be modified while in production + production_inmutable_fields = {"name", "split_seed", "description", "location"} + + if old_dset.state == "OPERATION": + inmutable_fields = inmutable_fields.union(production_inmutable_fields) + + dset_diffs = DeepDiff(new_dset.todict(), old_dset.todict()) + updated_fields = set(dset_diffs.affected_root_keys) + updated_inmutable_fields = updated_fields.intersection(inmutable_fields) + + if len(updated_inmutable_fields): + fields_msg = ", ".join(updated_inmutable_fields) + msg = ( + "The following fields can't be directly edited: " + + fields_msg + + ". For these changes, a new Dataset is required" + ) + raise InvalidArgumentError(msg) + + def update(self): + """Updates the benchmark on the server""" + if not self.is_registered: + raise MedperfException("Can't update an unregistered dataset") + body = self.todict() + config.comms.update_dataset(self.id, body) + @classmethod def __get_local_dict(cls, data_uid): dataset_path = os.path.join(storage_path(config.data_storage), str(data_uid)) diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..95c716382 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -75,3 +75,17 @@ def identifier(self): @property def is_registered(self): return self.id is not None + + +class Updatable(Uploadable): + @abstractmethod + def edit(self, **kwargs): + """Edits the current entity with the given fields + + Arguments: + kwargs (dict): Key-value pair of properties to edit and their corresponding new values + """ + + @abstractmethod + def update(self): + """Updates the current entity on the server""" diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 27abe0ee5..54de07085 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -91,7 +91,7 @@ def name_max_length(cls, v, *, values, **kwargs): class DeployableSchema(BaseModel): # TODO: This must change after allowing edits - state: str = "OPERATION" + state: str = "DEVELOPMENT" is_valid: bool = True diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index bae15b4f7..f82df058e 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -85,7 +85,7 @@ def setup_cube_comms(mocker, comms, all_ents, user_ents, uploaded): def generate_cubefile_fn(fs, path, filename): # all_ids = [ent["id"] if type(ent) == dict else ent for ent in all_ents] - def cubefile_fn(url, cube_path, *args): + def cubefile_fn(url, cube_path, *args, **kwargs): if url == "broken_url": raise CommunicationRetrievalError filepath = os.path.join(cube_path, path, filename) diff --git a/cli/medperf/tests/mocks/cube.py b/cli/medperf/tests/mocks/cube.py index 0ad5ba326..885c388e2 100644 --- a/cli/medperf/tests/mocks/cube.py +++ b/cli/medperf/tests/mocks/cube.py @@ -38,4 +38,4 @@ class TestCube(Cube): str ] = "https://test.com/additional_files.tar.gz" additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH - state: str = "PRODUCTION" + state: str = "OPERATION" diff --git a/cli/medperf/tests/mocks/dataset.py b/cli/medperf/tests/mocks/dataset.py index 9fb79d5ff..7a3d7528d 100644 --- a/cli/medperf/tests/mocks/dataset.py +++ b/cli/medperf/tests/mocks/dataset.py @@ -12,4 +12,4 @@ class TestDataset(Dataset): generated_uid: str = "generated_uid" generated_metadata: dict = {} status: Status = Status.APPROVED.value - state: str = "PRODUCTION" + state: str = "OPERATION" diff --git a/cli/requirements.txt b/cli/requirements.txt index ad66257c7..68f4a2eea 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -16,4 +16,5 @@ mlcube-singularity==0.0.9 validators==0.18.2 merge-args==0.1.4 synapseclient==2.7.0 +deepdiff==6.3.0 schema==0.7.5