From ec41369ba065ae4212323938bcf1cb869dbefd0a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Sat, 27 Apr 2024 07:21:32 +0200 Subject: [PATCH 01/18] refactor entities --- cli/cli_tests.sh | 4 +- cli/medperf/commands/benchmark/benchmark.py | 16 +- .../compatibility_test/compatibility_test.py | 8 +- .../commands/compatibility_test/utils.py | 16 +- cli/medperf/commands/dataset/dataset.py | 16 +- cli/medperf/commands/list.py | 14 +- cli/medperf/commands/mlcube/mlcube.py | 16 +- cli/medperf/commands/result/create.py | 5 +- cli/medperf/commands/result/result.py | 18 +- cli/medperf/commands/view.py | 12 +- cli/medperf/entities/benchmark.py | 212 ++---------------- cli/medperf/entities/cube.py | 161 +++---------- cli/medperf/entities/dataset.py | 187 +++------------ cli/medperf/entities/interface.py | 206 ++++++++++++++--- cli/medperf/entities/report.py | 92 ++------ cli/medperf/entities/result.py | 182 ++------------- cli/medperf/entities/schemas.py | 4 +- .../tests/commands/result/test_create.py | 3 + cli/medperf/tests/commands/test_list.py | 8 +- cli/medperf/tests/commands/test_view.py | 211 +++++++---------- cli/medperf/tests/entities/test_benchmark.py | 5 +- cli/medperf/tests/entities/test_cube.py | 13 +- cli/medperf/tests/entities/test_entity.py | 73 +++--- cli/medperf/tests/entities/utils.py | 85 ++++--- 24 files changed, 560 insertions(+), 1007 deletions(-) diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index ac6137b65..68764618a 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Printing MedPerf version" @@ -186,7 +185,7 @@ echo "Running data submission step" echo "=====================================" medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name="dataset_a" 
--description="mock dataset a" --location="mock location a" -y checkFailed "Data submission step failed" -DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1) +DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -212,7 +211,6 @@ DSET_A_GENUID=$(medperf dataset view $DSET_A_UID | grep generated_uid | cut -d " echo "\n" - ########################################################## echo "=====================================" echo "Moving storage to some other location" diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index f02d67cb4..35d719b0d 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local benchmarks"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered benchmarks" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), ): - """List benchmarks stored locally and remotely from the user""" + """List benchmarks""" EntityList.run( Benchmark, fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -162,10 +164,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -180,4 +182,4 @@ def view( ), ): """Displays the information of one or more benchmarks""" - EntityView.run(entity_id, Benchmark, format, local, mine, output) + EntityView.run(entity_id, Benchmark, format, unregistered, mine, output) diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index a3b25ac78..0bd4a4695 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -95,7 +95,11 @@ def run( @clean_except def list(): """List previously executed tests reports.""" - EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"]) + EntityList.run( + TestReport, + fields=["UID", "Data Source", "Model", "Evaluator"], + unregistered=True, + ) @app.command("view") @@ -116,4 +120,4 @@ def view( ), ): """Displays the information of one or more test reports""" - EntityView.run(entity_id, TestReport, format, output=output) + EntityView.run(entity_id, TestReport, format, unregistered=True, output=output) diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index a12ac5ea2..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -138,23 +138,23 @@ def create_test_dataset( # TODO: existing dataset could make problems # make some changes since this is a test dataset config.tmp_paths.remove(data_creation.dataset.path) - data_creation.dataset.write() if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = 
data_creation.dataset + old_generated_uid = dataset.generated_uid + old_path = dataset.path # prepare/check dataset DataPreparation.run(dataset.generated_uid) # update dataset generated_uid - old_path = dataset.path - generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) - dataset.generated_uid = generated_uid - dataset.write() - if dataset.input_data_hash != dataset.generated_uid: + new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) + if new_generated_uid != old_generated_uid: # move to a correct location if it underwent preparation - new_path = old_path.replace(dataset.input_data_hash, generated_uid) + new_path = old_path.replace(old_generated_uid, new_generated_uid) remove_path(new_path) os.rename(old_path, new_path) + dataset.generated_uid = new_generated_uid + dataset.write() - return generated_uid + return new_generated_uid diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index a27e36814..fc18022ac 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -17,17 +17,19 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local datasets"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered datasets" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), ): - """List datasets stored locally and remotely from the user""" + """List datasets""" EntityList.run( Dataset, fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"], - local_only=local, + unregistered=unregistered, mine_only=mine, mlcube=mlcube, ) @@ -149,8 +151,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local datasets if dataset ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered datasets if dataset ID is not provided", ), mine: bool = typer.Option( False, @@ -165,4 +169,4 @@ def view( ), ): """Displays the information of one or more datasets""" - EntityView.run(entity_id, Dataset, format, local, mine, output) + EntityView.run(entity_id, Dataset, format, unregistered, mine, output) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..b5d6226a4 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -10,27 +10,29 @@ class EntityList: def run( entity_class, fields, - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, **kwargs, ): """Lists all local datasets Args: - local_only (bool, optional): Display all local results. Defaults to False. + unregistered (bool, optional): Display only local unregistered results. Defaults to False. mine_only (bool, optional): Display all current-user results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. 
""" - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList( + entity_class, fields, unregistered, mine_only, **kwargs + ) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs): self.entity_class = entity_class self.fields = fields - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.filters = kwargs self.data = [] @@ -40,7 +42,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.display_dict() for entity in entities] diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 4c365e574..9256f35f2 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local mlcubes"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered mlcubes" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), ): - """List mlcubes stored locally and remotely from the user""" + """List mlcubes""" EntityList.run( Cube, fields=["UID", "Name", "State", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -148,8 +150,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local mlcubes if mlcube ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered mlcubes if mlcube ID is not provided", ), mine: bool = typer.Option( False, @@ -164,4 +168,4 @@ def view( ), ): """Displays the information of one or more mlcubes""" - EntityView.run(entity_id, Cube, format, local, mine, output) + EntityView.run(entity_id, Cube, format, unregistered, mine, output) diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 42f97d990..760dddc94 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -1,5 +1,6 @@ import os from typing import List, Optional +from medperf.account_management.account_management import get_medperf_user_data from medperf.commands.execution import Execution from medperf.entities.result import Result from tabulate import tabulate @@ -143,7 +144,9 @@ def __validate_models(self, benchmark_models): raise InvalidArgumentError(msg) def load_cached_results(self): - results = Result.all() + user_id = get_medperf_user_data()["id"] + results = Result.all(filters={"owner": user_id}) + results += Result.all(unregistered=True) benchmark_dset_results = [ result for result in results diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 6fbb3b08a..40b65c52e 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -62,17 +62,19 @@ def submit( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local results"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered results" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given 
benchmark" ), ): - """List results stored locally and remotely from the user""" + """List results""" EntityList.run( Result, fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, benchmark=benchmark, ) @@ -88,8 +90,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local results if result ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered results if result ID is not provided", ), mine: bool = typer.Option( False, @@ -107,4 +111,6 @@ def view( ), ): """Displays the information of one or more results""" - EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark) + EntityView.run( + entity_id, Result, format, unregistered, mine, output, benchmark=benchmark + ) diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index b4c242f0a..8c2a4179f 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -14,7 +14,7 @@ def run( entity_id: Union[int, str], entity_class: Entity, format: str = "yaml", - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, output: str = None, **kwargs, @@ -24,14 +24,14 @@ def run( Args: entity_id (Union[int, str]): Entity identifies entity_class (Entity): Entity type - local_only (bool, optional): Display all local entities. Defaults to False. + unregistered (bool, optional): Display only local unregistered entities. Defaults to False. mine_only (bool, optional): Display all current-user entities. Defaults to False. format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml. output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed. kwargs (dict): Additional parameters for filtering entity lists. 
""" entity_view = EntityView( - entity_id, entity_class, format, local_only, mine_only, output, **kwargs + entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ) entity_view.validate() entity_view.prepare() @@ -41,12 +41,12 @@ def run( entity_view.store() def __init__( - self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs + self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ): self.entity_id = entity_id self.entity_class = entity_class self.format = format - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.output = output self.filters = kwargs @@ -65,7 +65,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.todict() for entity in entities] diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 849ea3fcd..1d33efa95 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,18 +1,13 @@ -import os -from medperf.exceptions import MedperfException -import yaml -import logging -from typing import List, Optional, Union +from typing import List, Optional from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -35,6 +30,26 @@ class Benchmark(Entity, 
Uploadable, MedperfSchema, ApprovableSchema, DeployableS user_metadata: dict = {} is_active: bool = True + @staticmethod + def get_type(): + return "benchmark" + + @staticmethod + def get_storage_path(): + return config.benchmarks_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_benchmark + + @staticmethod + def get_metadata_filename(): + return config.benchmarks_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_benchmark + def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -44,53 +59,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" - path = config.benchmarks_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: - """Gets and creates instances of all retrievable benchmarks - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Benchmark]: a list of Benchmark instances. 
- """ - logging.info("Retrieving all benchmarks") - benchmarks = [] - - if not local_only: - benchmarks = cls.__remote_all(filters=filters) - - remote_uids = set([bmk.id for bmk in benchmarks]) - - local_benchmarks = cls.__local_all() - - benchmarks += [bmk for bmk in local_benchmarks if bmk.id not in remote_uids] - - return benchmarks @classmethod - def __remote_all(cls, filters: dict) -> List["Benchmark"]: - benchmarks = [] - try: - comms_fn = cls.__remote_prefilter(filters) - bmks_meta = comms_fn() - benchmarks = [cls(**meta) for meta in bmks_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all benchmarks from the server" - logging.warning(msg) - - return benchmarks - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,104 +75,6 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_benchmarks return comms_fn - @classmethod - def __local_all(cls) -> List["Benchmark"]: - benchmarks = [] - bmks_storage = config.benchmarks_folder - try: - uids = next(os.walk(bmks_storage))[1] - except StopIteration: - msg = "Couldn't iterate over benchmarks directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - benchmark = cls(**meta) - benchmarks.append(benchmark) - - return benchmarks - - @classmethod - def get( - cls, benchmark_uid: Union[str, int], local_only: bool = False - ) -> "Benchmark": - """Retrieves and creates a Benchmark instance from the server. - If benchmark already exists in the platform then retrieve that - version. - - Args: - benchmark_uid (str): UID of the benchmark. - comms (Comms): Instance of a communication interface. - - Returns: - Benchmark: a Benchmark instance with the retrieved data. 
- """ - - if not str(benchmark_uid).isdigit() or local_only: - return cls.__local_get(benchmark_uid) - - try: - return cls.__remote_get(benchmark_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Benchmark {benchmark_uid} from comms failed") - logging.info(f"Looking for benchmark {benchmark_uid} locally") - return cls.__local_get(benchmark_uid) - - @classmethod - def __remote_get(cls, benchmark_uid: int) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") - benchmark_dict = config.comms.get_benchmark(benchmark_uid) - benchmark = cls(**benchmark_dict) - benchmark.write() - return benchmark - - @classmethod - def __local_get(cls, benchmark_uid: Union[str, int]) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} locally") - benchmark_dict = cls.__get_local_dict(benchmark_uid) - benchmark = cls(**benchmark_dict) - return benchmark - - @classmethod - def __get_local_dict(cls, benchmark_uid) -> dict: - """Retrieves a local benchmark information - - Args: - benchmark_uid (str): uid of the local benchmark - - Returns: - dict: information of the benchmark - """ - logging.info(f"Retrieving benchmark {benchmark_uid} from local storage") - storage = config.benchmarks_folder - bmk_storage = os.path.join(storage, str(benchmark_uid)) - bmk_file = os.path.join(bmk_storage, config.benchmarks_filename) - if not os.path.exists(bmk_file): - raise InvalidArgumentError("No benchmark with the given uid could be found") - with open(bmk_file, "r") as f: - data = yaml.safe_load(f) - - return data - @classmethod def get_models_uids(cls, benchmark_uid: int) -> List[int]: """Retrieves the list of models associated to the benchmark @@ -221,43 +94,6 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: ] return models_uids - def todict(self) -> dict: - """Dictionary representation of the benchmark instance - - Returns: - dict: Dictionary containing benchmark information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the benchmark into disk - - Args: - filename (str, optional): name of the file. Defaults to config.benchmarks_filename. 
- - Returns: - str: path to the created benchmark file - """ - data = self.todict() - bmk_file = os.path.join(self.path, config.benchmarks_filename) - if not os.path.exists(bmk_file): - os.makedirs(self.path, exist_ok=True) - with open(bmk_file, "w") as f: - yaml.dump(data, f) - return bmk_file - - def upload(self): - """Uploads a benchmark to the server - - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test benchmarks.") - body = self.todict() - updated_body = config.comms.upload_benchmark(body) - return updated_body - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 589bc5a0b..fd0446194 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,7 +1,7 @@ import os import yaml import logging -from typing import List, Dict, Optional, Union +from typing import Dict, Optional, Union from pydantic import Field from pathlib import Path @@ -12,21 +12,15 @@ generate_tmp_path, spawn_and_kill, ) -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - ExecutionError, - InvalidEntityError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, MedperfSchema, DeployableSchema): """ Class representing an MLCube Container @@ -48,6 +42,26 @@ class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): metadata: dict = {} user_metadata: dict = {} + @staticmethod + def get_type(): + return 
"cube" + + @staticmethod + def get_storage_path(): + return config.cubes_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_cube_metadata + + @staticmethod + def get_metadata_filename(): + return config.cube_metadata_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_mlcube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -57,59 +71,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = self.name - path = config.cubes_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - # NOTE: maybe have these as @property, to have the same entity reusable - # before and after submission - self.path = path - self.cube_path = os.path.join(path, config.cube_filename) + self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(path, config.params_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: - """Class method for retrieving all retrievable MLCubes - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Cube]: List containing all cubes - """ - logging.info("Retrieving all cubes") - cubes = [] - if not local_only: - cubes = cls.__remote_all(filters=filters) - - remote_uids = set([cube.id for cube in cubes]) - - local_cubes = cls.__local_all() - - cubes += [cube for cube in local_cubes if cube.id not in remote_uids] - - return cubes - - @classmethod - def __remote_all(cls, filters: dict) -> List["Cube"]: - cubes = [] - - try: - comms_fn = cls.__remote_prefilter(filters) - cubes_meta = comms_fn() - cubes = [cls(**meta) for meta in cubes_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all cubes from the server" - logging.warning(msg) - - return cubes + self.params_path = os.path.join(self.path, config.params_filename) @classmethod - def __remote_prefilter(cls, filters: dict): + def _Entity__remote_prefilter(cls, filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: @@ -124,25 +92,6 @@ def __remote_prefilter(cls, filters: dict): return comms_fn - @classmethod - def __local_all(cls) -> List["Cube"]: - cubes = [] - cubes_folder = config.cubes_folder - try: - uids = next(os.walk(cubes_folder))[1] - logging.debug(f"Local cubes found: {uids}") - except StopIteration: - msg = "Couldn't iterate over cubes directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - cube = cls(**meta) - cubes.append(cube) - - return cubes - @classmethod def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": """Retrieves and creates a Cube instance from the comms. If cube already exists @@ -155,36 +104,12 @@ def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": Cube : a Cube instance with the retrieved data. 
""" - if not str(cube_uid).isdigit() or local_only: - cube = cls.__local_get(cube_uid) - else: - try: - cube = cls.__remote_get(cube_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting MLCube {cube_uid} from comms failed") - logging.info(f"Retrieving MLCube {cube_uid} from local storage") - cube = cls.__local_get(cube_uid) - + cube = super().get(cube_uid, local_only) if not cube.is_valid: raise InvalidEntityError("The requested MLCube is marked as INVALID.") cube.download_config_files() return cube - @classmethod - def __remote_get(cls, cube_uid: int) -> "Cube": - logging.debug(f"Retrieving mlcube {cube_uid} remotely") - meta = config.comms.get_cube_metadata(cube_uid) - cube = cls(**meta) - cube.write() - return cube - - @classmethod - def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": - logging.debug(f"Retrieving cube {cube_uid} locally") - local_meta = cls.__get_local_dict(cube_uid) - cube = cls(**local_meta) - return cube - def download_mlcube(self): url = self.git_mlcube_url path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) @@ -430,36 +355,6 @@ def get_config(self, identifier): return cube - def todict(self) -> Dict: - return self.extended_dict() - - def write(self): - cube_loc = str(Path(self.cube_path).parent) - meta_file = os.path.join(cube_loc, config.cube_metadata_filename) - os.makedirs(cube_loc, exist_ok=True) - with open(meta_file, "w") as f: - yaml.dump(self.todict(), f) - return meta_file - - def upload(self): - if self.for_test: - raise InvalidArgumentError("Cannot upload test mlcubes.") - cube_dict = self.todict() - updated_cube_dict = config.comms.upload_mlcube(cube_dict) - return updated_cube_dict - - @classmethod - def __get_local_dict(cls, uid): - cubes_folder = config.cubes_folder - meta_file = os.path.join(cubes_folder, str(uid), config.cube_metadata_filename) - if not os.path.exists(meta_file): - raise InvalidArgumentError( - "The requested mlcube information could not be found locally" - ) - 
with open(meta_file, "r") as f: - meta = yaml.safe_load(f) - return meta - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 4c210431f..f50e8d680 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,22 +1,17 @@ import os import yaml -import logging from pydantic import Field, validator -from typing import List, Optional, Union +from typing import Optional, Union from medperf.utils import remove_path -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) + import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, MedperfSchema, DeployableSchema): """ Class representing a Dataset @@ -37,6 +32,26 @@ class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): report: dict = {} submitted_as_prepared: bool + @staticmethod + def get_type(): + return "dataset" + + @staticmethod + def get_storage_path(): + return config.datasets_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_dataset + + @staticmethod + def get_metadata_filename(): + return config.reg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_dataset + @validator("data_preparation_mlcube", pre=True, always=True) def check_data_preparation_mlcube(cls, v, *, values, **kwargs): if not isinstance(v, int) and not values["for_test"]: @@ -48,13 +63,6 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - path = config.datasets_folder - if self.id: - path = os.path.join(path, str(self.id)) - 
else: - path = os.path.join(path, self.generated_uid) - - self.path = path self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) @@ -86,48 +94,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - def todict(self): - return self.extended_dict() - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: - """Gets and creates instances of all the locally prepared datasets - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Dataset]: a list of Dataset instances. - """ - logging.info("Retrieving all datasets") - dsets = [] - if not local_only: - dsets = cls.__remote_all(filters=filters) - - remote_uids = set([dset.id for dset in dsets]) - - local_dsets = cls.__local_all() - - dsets += [dset for dset in local_dsets if dset.id not in remote_uids] - - return dsets - - @classmethod - def __remote_all(cls, filters: dict) -> List["Dataset"]: - dsets = [] - try: - comms_fn = cls.__remote_prefilter(filters) - dsets_meta = comms_fn() - dsets = [cls(**meta) for meta in dsets_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all datasets from the server" - logging.warning(msg) - - return dsets - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -149,111 +117,6 @@ def func(): return comms_fn - @classmethod - def __local_all(cls) -> List["Dataset"]: - dsets = [] - datasets_folder = config.datasets_folder - try: - uids = next(os.walk(datasets_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the 
dataset directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - dset = cls(**local_meta) - dsets.append(dset) - - return dsets - - @classmethod - def get(cls, dset_uid: Union[str, int], local_only: bool = False) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - if not str(dset_uid).isdigit() or local_only: - return cls.__local_get(dset_uid) - - try: - return cls.__remote_get(dset_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Dataset {dset_uid} from comms failed") - logging.info(f"Looking for dataset {dset_uid} locally") - return cls.__local_get(dset_uid) - - @classmethod - def __remote_get(cls, dset_uid: int) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} remotely") - meta = config.comms.get_dataset(dset_uid) - dataset = cls(**meta) - dataset.write() - return dataset - - @classmethod - def __local_get(cls, dset_uid: Union[str, int]) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} locally") - local_meta = cls.__get_local_dict(dset_uid) - dataset = cls(**local_meta) - return dataset - - def write(self): - logging.info(f"Updating registration information for dataset: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test datasets.") - dataset_dict = self.todict() - updated_dataset_dict = config.comms.upload_dataset(dataset_dict) - return updated_dataset_dict - - @classmethod - def __get_local_dict(cls, data_uid): - dataset_path = os.path.join(config.datasets_folder, str(data_uid)) - regfile = os.path.join(dataset_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested dataset information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..7a5f0b5ef 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,77 +1,215 @@ from typing import List, Dict, Union -from abc import ABC, abstractmethod +from abc import ABC +import logging +import os +import yaml +from medperf.exceptions import MedperfException, InvalidArgumentError +from medperf.entities.schemas import MedperfBaseSchema -class Entity(ABC): - @abstractmethod - def all( - cls, local_only: bool = False, comms_func: callable = None - ) -> 
List["Entity"]: +class Entity(MedperfBaseSchema, ABC): + @staticmethod + def get_type(): + raise NotImplementedError() + + @staticmethod + def get_storage_path(): + raise NotImplementedError() + + @staticmethod + def get_comms_retriever(): + raise NotImplementedError() + + @staticmethod + def get_metadata_filename(): + raise NotImplementedError() + + @staticmethod + def get_comms_uploader(): + raise NotImplementedError() + + @property + def identifier(self): + return self.id or self.generated_uid + + @property + def is_registered(self): + return self.id is not None + + @property + def path(self): + storage_path = self.get_storage_path() + return os.path.join(storage_path, str(self.identifier)) + + @classmethod + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + unregistered (bool, optional): Wether to retrieve only unregistered local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + Returns: List[Entity]: a list of entities. 
""" + logging.info(f"Retrieving all {cls.get_type()} entities") + if unregistered: + if filters: + raise InvalidArgumentError( + "Filtering is not supported for unregistered entities" + ) + return cls.__unregistered_all() + return cls.__remote_all(filters=filters) + + @classmethod + def __remote_all(cls, filters: dict) -> List["Entity"]: + comms_fn = cls.__remote_prefilter(filters) + entity_meta = comms_fn() + entities = [cls(**meta) for meta in entity_meta] + return entities + + @classmethod + def __unregistered_all(cls) -> List["Entity"]: + entities = [] + storage_path = cls.get_storage_path() + try: + uids = next(os.walk(storage_path))[1] + except StopIteration: + msg = f"Couldn't iterate over the {cls.get_type()} storage" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + if uid.isdigit(): + continue + meta = cls.__get_local_dict(uid) + entity = cls(**meta) + entities.append(entity) + + return entities + + @classmethod + def __remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + raise NotImplementedError - @abstractmethod - def get(cls, uid: Union[str, int]) -> "Entity": + @classmethod + def get(cls, uid: Union[str, int], local_only: bool = False) -> "Entity": """Gets an instance of the respective entity. Wether this requires only local read or remote calls depends on the implementation. 
Args: uid (str): Unique Identifier to retrieve the entity + local_only (bool): If True, the entity will be retrieved locally Returns: Entity: Entity Instance associated to the UID """ - @abstractmethod - def todict(self) -> Dict: - """Dictionary representation of the entity + if not str(uid).isdigit() or local_only: + return cls.__local_get(uid) + return cls.__remote_get(uid) + + @classmethod + def __remote_get(cls, uid: int) -> "Entity": + """Retrieves and creates an entity instance from the comms instance. + + Args: + uid (int): server UID of the entity Returns: - Dict: Dictionary containing information about the entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} remotely") + comms_func = cls.get_comms_retriever() + entity_dict = comms_func(uid) + entity = cls(**entity_dict) + entity.write() + return entity - @abstractmethod - def write(self) -> str: - """Writes the entity to the local storage + @classmethod + def __local_get(cls, uid: Union[str, int]) -> "Entity": + """Retrieves and creates an entity instance from the local storage. 
+ + Args: + uid (str|int): UID of the entity Returns: - str: Path to the stored entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} locally") + entity_dict = cls.__get_local_dict(uid) + entity = cls(**entity_dict) + return entity - @abstractmethod - def display_dict(self) -> dict: - """Returns a dictionary of entity properties that can be displayed - to a user interface using a verbose name of the property rather than - the internal names + @classmethod + def __get_local_dict(cls, uid: Union[str, int]) -> dict: + """Retrieves a local entity information + + Args: + uid (str): uid of the local entity Returns: - dict: the display dictionary + dict: information of the entity """ + logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") + storage_path = cls.get_storage_path() + metadata_filename = cls.get_metadata_filename() + bmk_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(bmk_file): + raise InvalidArgumentError( + f"No {cls.get_type()} with the given uid could be found" + ) + with open(bmk_file, "r") as f: + data = yaml.safe_load(f) + + return data + + def write(self) -> str: + """Writes the entity to the local storage + Returns: + str: Path to the stored entity + """ + data = self.todict() + metadata_filename = self.get_metadata_filename() + entity_file = os.path.join(self.path, metadata_filename) + os.makedirs(self.path, exist_ok=True) + with open(entity_file, "w") as f: + yaml.dump(data, f) + return entity_file -class Uploadable: - @abstractmethod def upload(self) -> Dict: """Upload the entity-related information to the communication's interface Returns: Dict: Dictionary with the updated entity information """ + if self.for_test: + raise InvalidArgumentError( + f"This test {self.get_type()} cannot be uploaded." 
+ ) + body = self.todict() + comms_func = self.get_comms_uploader() + updated_body = comms_func(body) + return updated_body - @property - def identifier(self): - return self.id or self.generated_uid + def display_dict(self) -> dict: + """Returns a dictionary of entity properties that can be displayed + to a user interface using a verbose name of the property rather than + the internal names - @property - def is_registered(self): - return self.id is not None + Returns: + dict: the display dictionary + """ + raise NotImplementedError diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index c76f09894..65147e558 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,16 +1,11 @@ import hashlib -import os -import yaml -import logging from typing import List, Union, Optional -from medperf.entities.schemas import MedperfBaseSchema import medperf.config as config -from medperf.exceptions import InvalidArgumentError from medperf.entities.interface import Entity -class TestReport(Entity, MedperfBaseSchema): +class TestReport(Entity): """ Class representing a compatibility test report entry @@ -35,11 +30,23 @@ class TestReport(Entity, MedperfBaseSchema): data_evaluator_mlcube: Union[int, str] results: Optional[dict] + @staticmethod + def get_type(): + return "report" + + @staticmethod + def get_storage_path(): + return config.tests_folder + + @staticmethod + def get_metadata_filename(): + return config.test_report_file + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.id = None + self.for_test = True self.generated_uid = self.__generate_uid() - path = config.tests_folder - self.path = os.path.join(path, self.generated_uid) def __generate_uid(self): """A helper that generates a unique hash for a test report.""" @@ -52,71 +59,14 @@ def set_results(self, results): self.results = results @classmethod - def all( - cls, local_only: bool = False, mine_only: bool = False - ) -> List["TestReport"]: 
- """Gets and creates instances of test reports. - Arguments are only specified for compatibility with - `Entity.List` and `Entity.View`, but they don't contribute to - the logic. - - Returns: - List[TestReport]: List containing all test reports - """ - logging.info("Retrieving all reports") - reports = [] - tests_folder = config.tests_folder - try: - uids = next(os.walk(tests_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the tests directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - report = cls(**local_meta) - reports.append(report) - - return reports - - @classmethod - def get(cls, report_uid: str) -> "TestReport": - """Retrieves and creates a TestReport instance obtained the user's machine - - Args: - report_uid (str): UID of the TestReport instance - - Returns: - TestReport: Specified TestReport instance - """ - logging.debug(f"Retrieving report {report_uid}") - report_dict = cls.__get_local_dict(report_uid) - report = cls(**report_dict) - report.write() - return report - - def todict(self): - return self.extended_dict() - - def write(self): - report_file = os.path.join(self.path, config.test_report_file) - os.makedirs(self.path, exist_ok=True) - with open(report_file, "w") as f: - yaml.dump(self.todict(), f) - return report_file + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: + assert unregistered, "Reports are only unregistered" + assert filters == {}, "Reports cannot be filtered" + return super().all(unregistered=True, filters={}) @classmethod - def __get_local_dict(cls, local_uid): - report_path = os.path.join(config.tests_folder, str(local_uid)) - report_file = os.path.join(report_path, config.test_report_file) - if not os.path.exists(report_file): - raise InvalidArgumentError( - f"The requested report {local_uid} could not be retrieved" - ) - with open(report_file, "r") as f: - report_info = yaml.safe_load(f) - return report_info 
+ def get(cls, report_uid: str, local_only: bool = False) -> "TestReport": + return super().get(report_uid, local_only=True) def display_dict(self): if self.data_path: diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index c82add87b..af4098521 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,16 +1,10 @@ -import os -import yaml -import logging -from typing import List, Union - -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema import medperf.config as config -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.account_management import get_medperf_user_data -class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): +class Result(Entity, MedperfSchema, ApprovableSchema): """ Class representing a Result entry @@ -28,59 +22,34 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): metadata: dict = {} user_metadata: dict = {} - def __init__(self, *args, **kwargs): - """Creates a new result instance""" - super().__init__(*args, **kwargs) - - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" - path = config.results_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: - """Gets and creates instances of all the user's results - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Result]: List containing all results - """ - logging.info("Retrieving all results") - results = [] - if not local_only: - results = cls.__remote_all(filters=filters) - - remote_uids = set([result.id for result in results]) + @staticmethod + def get_type(): + return "result" - local_results = cls.__local_all() + @staticmethod + def get_storage_path(): + return config.results_folder - results += [res for res in local_results if res.id not in remote_uids] + @staticmethod + def get_comms_retriever(): + return config.comms.get_result - return results + @staticmethod + def get_metadata_filename(): + return config.results_info_file - @classmethod - def __remote_all(cls, filters: dict) -> List["Result"]: - results = [] + @staticmethod + def get_comms_uploader(): + return config.comms.upload_result - try: - comms_fn = cls.__remote_prefilter(filters) - results_meta = comms_fn() - results = [cls(**meta) for meta in results_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all results from the server" - logging.warning(msg) + def __init__(self, *args, **kwargs): + """Creates a new result instance""" + super().__init__(*args, **kwargs) - return results + self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,113 +73,6 @@ def get_benchmark_results(): return comms_fn - @classmethod - def __local_all(cls) -> List["Result"]: - results = [] - results_folder = config.results_folder - try: - uids = next(os.walk(results_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - result = cls(**local_meta) - results.append(result) - - return results - - @classmethod 
- def get(cls, result_uid: Union[str, int], local_only: bool = False) -> "Result": - """Retrieves and creates a Result instance obtained from the platform. - If the result instance already exists in the user's machine, it loads - the local instance - - Args: - result_uid (str): UID of the Result instance - - Returns: - Result: Specified Result instance - """ - if not str(result_uid).isdigit() or local_only: - return cls.__local_get(result_uid) - - try: - return cls.__remote_get(result_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Result {result_uid} from comms failed") - logging.info(f"Looking for result {result_uid} locally") - return cls.__local_get(result_uid) - - @classmethod - def __remote_get(cls, result_uid: int) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} remotely") - meta = config.comms.get_result(result_uid) - result = cls(**meta) - result.write() - return result - - @classmethod - def __local_get(cls, result_uid: Union[str, int]) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} locally") - local_meta = cls.__get_local_dict(result_uid) - result = cls(**local_meta) - return result - - def todict(self): - return self.extended_dict() - - def upload(self): - """Uploads the results to the comms - - Args: - comms (Comms): Instance of the communications interface. 
- """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test results.") - results_info = self.todict() - updated_results_info = config.comms.upload_result(results_info) - return updated_results_info - - def write(self): - result_file = os.path.join(self.path, config.results_info_file) - os.makedirs(self.path, exist_ok=True) - with open(result_file, "w") as f: - yaml.dump(self.todict(), f) - return result_file - - @classmethod - def __get_local_dict(cls, local_uid): - result_path = os.path.join(config.results_folder, str(local_uid)) - result_file = os.path.join(result_path, config.results_info_file) - if not os.path.exists(result_file): - raise InvalidArgumentError( - f"The requested result {local_uid} could not be retrieved" - ) - with open(result_file, "r") as f: - results_info = yaml.safe_load(f) - return results_info - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 0e7a54291..cac3d3a01 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -46,7 +46,7 @@ def dict(self, *args, **kwargs) -> dict: out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict - def extended_dict(self) -> dict: + def todict(self) -> dict: """Dictionary containing both original and alias fields Returns: @@ -74,7 +74,7 @@ class Config: use_enum_values = True -class MedperfSchema(MedperfBaseSchema): +class MedperfSchema(BaseModel): for_test: bool = False id: Optional[int] name: str = Field(..., max_length=64) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 74299c77e..c69544781 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -57,6 +57,9 @@ def mock_result_all(mocker, state_variables): TestResult(benchmark=triplet[0], model=triplet[1], dataset=triplet[2]) for triplet in 
cached_results_triplets ] + mocker.patch( + PATCH_EXECUTION.format("get_medperf_user_data", return_value={"id": 1}) + ) mocker.patch(PATCH_EXECUTION.format("Result.all"), return_value=results) diff --git a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..ce7035960 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -47,18 +47,18 @@ def set_common_attributes(self, setup): self.state_variables = state_variables self.spies = spies - @pytest.mark.parametrize("local_only", [False, True]) + @pytest.mark.parametrize("unregistered", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + def test_entity_all_is_called_properly(self, mocker, unregistered, mine_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(Entity, [], unregistered, mine_only) # Assert self.spies["all"].assert_called_once_with( - local_only=local_only, filters=filters + unregistered=unregistered, filters=filters ) @pytest.mark.parametrize("fields", [["UID", "MLCube"]]) diff --git a/cli/medperf/tests/commands/test_view.py b/cli/medperf/tests/commands/test_view.py index a2dddfeda..0ffe0fb13 100644 --- a/cli/medperf/tests/commands/test_view.py +++ b/cli/medperf/tests/commands/test_view.py @@ -1,143 +1,86 @@ import pytest -import yaml -import json from medperf.entities.interface import Entity -from medperf.exceptions import InvalidArgumentError from medperf.commands.view import EntityView - -def expected_output(entities, format): - if isinstance(entities, list): - data = [entity.todict() for entity in entities] - else: - data = entities.todict() - - if format == "yaml": - return yaml.dump(data) - if format == "json": - return json.dumps(data) - - -def generate_entity(id, mocker): - entity = mocker.create_autospec(spec=Entity) - 
mocker.patch.object(entity, "todict", return_value={"id": id}) - return entity +PATCH_VIEW = "medperf.commands.view.{}" @pytest.fixture -def ui_spy(mocker, ui): - return mocker.patch.object(ui, "print") +def entity(mocker): + return mocker.create_autospec(Entity) -@pytest.fixture( - params=[{"local": ["1", "2", "3"], "remote": ["4", "5", "6"], "user": ["4"]}] -) -def setup(request, mocker): - local_ids = request.param.get("local", []) - remote_ids = request.param.get("remote", []) - user_ids = request.param.get("user", []) - all_ids = list(set(local_ids + remote_ids + user_ids)) - - local_entities = [generate_entity(id, mocker) for id in local_ids] - remote_entities = [generate_entity(id, mocker) for id in remote_ids] - user_entities = [generate_entity(id, mocker) for id in user_ids] - all_entities = list(set(local_entities + remote_entities + user_entities)) - - def mock_all(filters={}, local_only=False): - if "owner" in filters: - return user_entities - if local_only: - return local_entities - return all_entities - - def mock_get(entity_id): - if entity_id in all_ids: - return generate_entity(entity_id, mocker) - else: - raise InvalidArgumentError - - mocker.patch("medperf.commands.view.get_medperf_user_data", return_value={"id": 1}) - mocker.patch.object(Entity, "all", side_effect=mock_all) - mocker.patch.object(Entity, "get", side_effect=mock_get) - - return local_entities, remote_entities, user_entities, all_entities - - -class TestViewEntityID: - def test_view_displays_entity_if_given(self, mocker, setup, ui_spy): - # Arrange - entity_id = "1" - entity = generate_entity(entity_id, mocker) - output = expected_output(entity, "yaml") - - # Act - EntityView.run(entity_id, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_all_if_no_id(self, setup, ui_spy): - # Arrange - *_, entities = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity) - - # Assert - 
ui_spy.assert_called_once_with(output) - - -class TestViewFilteredEntities: - def test_view_displays_local_entities(self, setup, ui_spy): - # Arrange - entities, *_ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, local_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_user_entities(self, setup, ui_spy): - # Arrange - *_, entities, _ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, mine_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - -@pytest.mark.parametrize("entity_id", ["4", None]) -@pytest.mark.parametrize("format", ["yaml", "json"]) -class TestViewOutput: - @pytest.fixture - def output(self, setup, mocker, entity_id, format): - if entity_id is None: - *_, entities = setup - return expected_output(entities, format) - else: - entity = generate_entity(entity_id, mocker) - return expected_output(entity, format) - - def test_view_displays_specified_format(self, entity_id, output, ui_spy, format): - # Act - EntityView.run(entity_id, Entity, format=format) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_stores_specified_format(self, entity_id, output, format, fs): - # Arrange - filename = "file.txt" - - # Act - EntityView.run(entity_id, Entity, format=format, output=filename) - - # Assert - contents = open(filename, "r").read() - assert contents == output +@pytest.fixture +def entity_view(mocker): + view_class = EntityView(None, Entity, "", "", "", "") + return view_class + + +def test_prepare_with_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = 1 + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + get_spy.assert_called_once_with(1) + all_spy.assert_not_called() + assert not isinstance(entity_view.data, list) + 
+ +def test_prepare_with_no_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once() + get_spy.assert_not_called() + assert isinstance(entity_view.data, list) + + +@pytest.mark.parametrize("unregistered", [False, True]) +def test_prepare_with_no_id_calls_all_with_unregistered_properly( + mocker, entity_view, entity, unregistered +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = unregistered + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=unregistered, filters={}) + + +@pytest.mark.parametrize("filters", [{}, {"f1": "v1"}]) +@pytest.mark.parametrize("mine_only", [False, True]) +def test_prepare_with_no_id_calls_all_with_proper_filters( + mocker, entity_view, entity, filters, mine_only +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = False + entity_view.filters = filters + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + mocker.patch(PATCH_VIEW.format("get_medperf_user_data"), return_value={"id": 1}) + if mine_only: + filters["owner"] = 1 + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=False, filters=filters) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 3f1fde2e2..c36771e12 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -9,8 +9,9 @@ @pytest.fixture( params={ - "local": [1, 2, 3], - "remote": [4, 5, 6], + "unregistered": ["b1", "b2"], 
+ "local": ["b1", "b2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], "user": [4], "models": [10, 11], } diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index b82b9a0e8..51234f6e3 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,7 +24,14 @@ } -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["c1", "c2"], + "local": ["c1", "c2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) @@ -282,7 +289,9 @@ def test_run_stops_execution_if_child_fails(self, mocker, setup, task): cube.run(task) -@pytest.mark.parametrize("setup", [{"local": [DEFAULT_CUBE]}], indirect=True) +@pytest.mark.parametrize( + "setup", [{"local": [DEFAULT_CUBE], "remote": [DEFAULT_CUBE]}], indirect=True +) @pytest.mark.parametrize("task", ["task"]) @pytest.mark.parametrize( "out_key,out_value", diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index c636b2c26..b9d309f39 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -15,7 +15,7 @@ setup_result_fs, setup_result_comms, ) -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -23,7 +23,14 @@ def Implementation(request): return request.param -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, Implementation, fs): local_ids = request.param.get("local", []) remote_ids = 
request.param.get("remote", []) @@ -54,39 +61,52 @@ def setup(request, mocker, comms, Implementation, fs): @pytest.mark.parametrize( "setup", - [{"local": [283, 17, 493], "remote": [283, 1, 2], "user": [2]}], + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 283], + "remote": [283, 1, 2], + "user": [2], + } + ], indirect=True, ) class TestAll: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.ids = setup + self.unregistered_ids = set(self.ids["unregistered"]) self.local_ids = set(self.ids["local"]) self.remote_ids = set(self.ids["remote"]) self.user_ids = set(self.ids["user"]) - def test_all_returns_all_remote_and_local(self, Implementation): - # Arrange - all_ids = self.local_ids.union(self.remote_ids) - + def test_all_returns_all_remote_by_default(self, Implementation): # Act entities = Implementation.all() # Assert retrieved_ids = set([e.todict()["id"] for e in entities]) - assert all_ids == retrieved_ids + assert self.remote_ids == retrieved_ids - def test_all_local_only_returns_all_local(self, Implementation): + def test_all_unregistered_returns_all_unregistered(self, Implementation): # Act - entities = Implementation.all(local_only=True) + entities = Implementation.all(unregistered=True) # Assert - retrieved_ids = set([e.todict()["id"] for e in entities]) - assert self.local_ids == retrieved_ids + retrieved_names = set([e.name for e in entities]) + assert self.unregistered_ids == retrieved_names @pytest.mark.parametrize( - "setup", [{"local": [78], "remote": [479, 42, 7, 1]}], indirect=True + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 479], + "remote": [479, 42, 7, 1], + } + ], + indirect=True, ) class TestGet: def test_get_retrieves_entity_from_server(self, Implementation, setup): @@ -99,30 +119,20 @@ def test_get_retrieves_entity_from_server(self, Implementation, setup): # Assert assert entity.todict()["id"] == id - def test_get_retrieves_entity_local_if_not_on_server(self, Implementation, 
setup): - # Arrange - id = setup["local"][0] - - # Act - entity = Implementation.get(id) - - # Assert - assert entity.todict()["id"] == id - def test_get_raises_error_if_nonexistent(self, Implementation, setup): # Arrange id = str(19283) # Act & Assert - with pytest.raises(InvalidArgumentError): + with pytest.raises(CommunicationRetrievalError): Implementation.get(id) -@pytest.mark.parametrize("setup", [{"local": [742]}], indirect=True) +@pytest.mark.parametrize("setup", [{"remote": [742]}], indirect=True) class TestToDict: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): - self.id = setup["local"][0] + self.id = setup["remote"][0] def test_todict_returns_dict_representation(self, Implementation): # Arrange @@ -147,7 +157,16 @@ def test_todict_can_recreate_object(self, Implementation): assert ent_dict == ent_copy_dict -@pytest.mark.parametrize("setup", [{"local": [36]}], indirect=True) +@pytest.mark.parametrize( + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2"], + } + ], + indirect=True, +) class TestUpload: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 522251ca7..19c3178e3 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -15,14 +15,17 @@ # Setup Benchmark def setup_benchmark_fs(ents, fs): - bmks_path = config.benchmarks_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename) - bmk_contents = TestBenchmark(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + bmk_contents = TestBenchmark(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + bmk_contents = TestBenchmark(id=str(ent)) + else: + bmk_contents = TestBenchmark(id=None, name=ent) + 
bmk_contents.generated_uid = ent + + bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -30,7 +33,7 @@ def setup_benchmark_fs(ents, fs): cubes_ids = list(set(cubes_ids)) setup_cube_fs(cubes_ids, fs) try: - fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.dict())) + fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.todict())) except FileExistsError: pass @@ -51,17 +54,18 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Cube def setup_cube_fs(ents, fs): - cubes_path = config.cubes_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - meta_cube_file = os.path.join( - cubes_path, str(id), config.cube_metadata_filename - ) - cube = TestCube(**ent) - meta = cube.dict() + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + cube = TestCube(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + cube = TestCube(id=str(ent)) + else: + cube = TestCube(id=None, name=ent) + cube.generated_uid = ent + + meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) except FileExistsError: @@ -124,18 +128,21 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): - dsets_path = config.datasets_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - reg_dset_file = os.path.join(dsets_path, str(id), config.reg_file) - dset_contents = TestDataset(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + dset_contents = TestDataset(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + dset_contents = 
TestDataset(id=str(ent)) + else: + dset_contents = TestDataset(id=None, name=ent) + dset_contents.generated_uid = ent + + reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: - fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.dict())) + fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.todict())) except FileExistsError: pass @@ -155,22 +162,26 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Result def setup_result_fs(ents, fs): - results_path = config.results_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - result_file = os.path.join(results_path, str(id), config.results_info_file) - bmk_id = ent.get("benchmark", 1) - cube_id = ent.get("model", 1) - dataset_id = ent.get("dataset", 1) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + result_contents = TestResult(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + result_contents = TestResult(id=str(ent)) + else: + result_contents = TestResult(id=None, name=ent) + result_contents.generated_uid = ent + + result_file = os.path.join(result_contents.path, config.results_info_file) + bmk_id = result_contents.benchmark + cube_id = result_contents.model + dataset_id = result_contents.dataset setup_benchmark_fs([bmk_id], fs) setup_cube_fs([cube_id], fs) setup_dset_fs([dataset_id], fs) - result_contents = TestResult(**ent) + try: - fs.create_file(result_file, contents=yaml.dump(result_contents.dict())) + fs.create_file(result_file, contents=yaml.dump(result_contents.todict())) except FileExistsError: pass From 9684e3643951b1d85e47a031ac2430c0d1984b65 Mon Sep 17 00:00:00 2001 From: hasan7n <78664424+hasan7n@users.noreply.github.com> Date: Tue, 14 May 2024 23:48:37 +0200 Subject: [PATCH 02/18] Update cli/medperf/entities/interface.py Co-authored-by: 
Viacheslav Kukushkin --- cli/medperf/entities/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 7a5f0b5ef..f9342f21f 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -25,7 +25,7 @@ def get_metadata_filename(): raise NotImplementedError() @staticmethod - def get_comms_uploader(): + def get_comms_uploader() -> Callable[dict, dict]: raise NotImplementedError() @property From e807c0ec0b5cbb1c4b8c3d47125343130d49ffee Mon Sep 17 00:00:00 2001 From: hasan7n <78664424+hasan7n@users.noreply.github.com> Date: Tue, 14 May 2024 23:55:34 +0200 Subject: [PATCH 03/18] Apply suggestions from code review Co-authored-by: Viacheslav Kukushkin --- cli/medperf/commands/list.py | 2 +- cli/medperf/entities/interface.py | 2 +- cli/medperf/entities/report.py | 11 +++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index b5d6226a4..ddf6aae6e 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -18,7 +18,7 @@ def run( Args: unregistered (bool, optional): Display only local unregistered results. Defaults to False. - mine_only (bool, optional): Display all current-user results. Defaults to False. + mine_only (bool, optional): Display all registered current-user results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. 
""" entity_list = EntityList( diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index f9342f21f..5233c6988 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -29,7 +29,7 @@ def get_comms_uploader() -> Callable[dict, dict]: raise NotImplementedError() @property - def identifier(self): + def identifier(self) -> Union[int, str]: return self.id or self.generated_uid @property diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 65147e558..6e3128b1a 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -65,8 +65,15 @@ def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: return super().all(unregistered=True, filters={}) @classmethod - def get(cls, report_uid: str, local_only: bool = False) -> "TestReport": - return super().get(report_uid, local_only=True) + def get(cls, uid: str, local_only: bool = False) -> "TestReport": + """Gets an instance of the TestReport. ignores local_only inherited flag as TestReport is always a local entity. + Args: + uid (str): Report Unique Identifier + local_only (bool): ignored. 
Left for aligning with parent Entity class + Returns: + TestReport: Report Instance associated to the UID + """ + return super().get(uid, local_only=True) def display_dict(self): if self.data_path: From a87e3fddc09d1cdf35d671ca28f3a5243b7c3dec Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 04:53:51 +0200 Subject: [PATCH 04/18] update outdated result submission code --- cli/medperf/commands/result/submit.py | 29 +++++++++++++-------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/cli/medperf/commands/result/submit.py b/cli/medperf/commands/result/submit.py index 15649ee04..b69a596ce 100644 --- a/cli/medperf/commands/result/submit.py +++ b/cli/medperf/commands/result/submit.py @@ -3,7 +3,6 @@ from medperf.exceptions import CleanExit from medperf.utils import remove_path, dict_pretty_print, approval_prompt from medperf.entities.result import Result -from medperf.enums import Status from medperf import config @@ -11,6 +10,7 @@ class ResultSubmission: @classmethod def run(cls, result_uid, approved=False): sub = cls(result_uid, approved=approved) + sub.get_result() updated_result_dict = sub.upload_results() sub.to_permanent_path(updated_result_dict) sub.write(updated_result_dict) @@ -21,27 +21,26 @@ def __init__(self, result_uid, approved=False): self.ui = config.ui self.approved = approved - def request_approval(self, result): - if result.approval_status == Status.APPROVED: - return True + def get_result(self): + self.result = Result.get(self.result_uid) - dict_pretty_print(result.results) + def request_approval(self): + dict_pretty_print(self.result.results) self.ui.print("Above are the results generated by the model") approved = approval_prompt( - "Do you approve uploading the presented results to the MLCommons comms? [Y/n]" + "Do you approve uploading the presented results to the MedPerf? 
[Y/n]" ) return approved def upload_results(self): - result = Result.get(self.result_uid) - approved = self.approved or self.request_approval(result) + approved = self.approved or self.request_approval() if not approved: raise CleanExit("Results upload operation cancelled") - updated_result_dict = result.upload() + updated_result_dict = self.result.upload() return updated_result_dict def to_permanent_path(self, result_dict: dict): @@ -50,12 +49,12 @@ def to_permanent_path(self, result_dict: dict): Args: result_dict (dict): updated results dictionary """ - result = Result(**result_dict) - result_storage = config.results_folder - old_res_loc = os.path.join(result_storage, result.generated_uid) - new_res_loc = result.path - remove_path(new_res_loc) - os.rename(old_res_loc, new_res_loc) + + old_result_loc = self.result.path + updated_result = Result(**result_dict) + new_result_loc = updated_result.path + remove_path(new_result_loc) + os.rename(old_result_loc, new_result_loc) def write(self, updated_result_dict): result = Result(**updated_result_dict) From 2e242617423810b6bd5a2d04134bbed4349cf087 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 04:54:33 +0200 Subject: [PATCH 05/18] refactor schemas --- cli/medperf/entities/benchmark.py | 6 +++--- cli/medperf/entities/cube.py | 8 ++++---- cli/medperf/entities/dataset.py | 6 +++--- cli/medperf/entities/interface.py | 10 +++++----- cli/medperf/entities/report.py | 12 ++++++++++-- cli/medperf/entities/result.py | 6 +++--- cli/medperf/entities/schemas.py | 30 ++++++++++++++---------------- 7 files changed, 42 insertions(+), 36 deletions(-) diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 1d33efa95..e35849299 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -3,11 +3,11 @@ import medperf.config as config from medperf.entities.interface import Entity -from medperf.entities.schemas import MedperfSchema, ApprovableSchema, 
DeployableSchema +from medperf.entities.schemas import ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class Benchmark(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" + self.local_id = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" @classmethod def _Entity__remote_prefilter(cls, filters: dict) -> callable: diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index f4cdf5280..61b71867f 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -13,14 +13,14 @@ spawn_and_kill, ) from medperf.entities.interface import Entity -from medperf.entities.schemas import MedperfSchema, DeployableSchema +from medperf.entities.schemas import DeployableSchema from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, MedperfSchema, DeployableSchema): +class Cube(Entity, DeployableSchema): """ Class representing an MLCube Container @@ -70,7 +70,7 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = self.name + self.local_id = self.name self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: @@ -245,7 +245,7 @@ def run( """ kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f" --mlcube=\"{self.cube_path}\" --task={task} --platform={config.platform} --network=none" + cmd += f' 
--mlcube="{self.cube_path}" --task={task} --platform={config.platform} --network=none' if config.gpus is not None: cmd += f" --gpus={config.gpus}" if read_protected_input: diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index f50e8d680..b65989182 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -5,13 +5,13 @@ from medperf.utils import remove_path from medperf.entities.interface import Entity -from medperf.entities.schemas import MedperfSchema, DeployableSchema +from medperf.entities.schemas import DeployableSchema import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, MedperfSchema, DeployableSchema): +class Dataset(Entity, DeployableSchema): """ Class representing a Dataset @@ -62,7 +62,7 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + self.local_id = self.generated_uid self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 5233c6988..8709a9589 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,13 +1,13 @@ -from typing import List, Dict, Union +from typing import List, Dict, Union, Callable from abc import ABC import logging import os import yaml from medperf.exceptions import MedperfException, InvalidArgumentError -from medperf.entities.schemas import MedperfBaseSchema +from medperf.entities.schemas import MedperfSchema -class Entity(MedperfBaseSchema, ABC): +class Entity(MedperfSchema, ABC): @staticmethod def get_type(): raise NotImplementedError() @@ -25,12 +25,12 @@ def get_metadata_filename(): raise NotImplementedError() @staticmethod - def get_comms_uploader() -> Callable[dict, dict]: + def 
get_comms_uploader() -> Callable[[dict], dict]: raise NotImplementedError() @property def identifier(self) -> Union[int, str]: - return self.id or self.generated_uid + return self.id or self.local_id @property def is_registered(self): diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 6e3128b1a..7cf220e8e 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -18,8 +18,16 @@ class TestReport(Entity): - model cube - evaluator cube - results + + Note: This entity is only a local one, there is no TestReports on the server + However, we still use the same Entity interface used by other entities + in order to reduce repeated code. Consequently, we mocked a few methods + and attributes inherited from the Entity interface that are not relevant to + this entity, such as the `name` and `id` attributes, and such as + the `get` and `all` methods. """ + name: Optional[str] = "name" demo_dataset_url: Optional[str] demo_dataset_hash: Optional[str] data_path: Optional[str] @@ -46,7 +54,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.id = None self.for_test = True - self.generated_uid = self.__generate_uid() + self.local_id = self.__generate_uid() def __generate_uid(self): """A helper that generates a unique hash for a test report.""" @@ -84,7 +92,7 @@ def display_dict(self): data_source = f"{self.prepared_data_hash}" return { - "UID": self.generated_uid, + "UID": self.local_id, "Data Source": data_source, "Model": ( self.model if isinstance(self.model, int) else self.model[:27] + "..." 
diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index af4098521..f5cc5243b 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,10 +1,10 @@ from medperf.entities.interface import Entity -from medperf.entities.schemas import MedperfSchema, ApprovableSchema +from medperf.entities.schemas import ApprovableSchema import medperf.config as config from medperf.account_management import get_medperf_user_data -class Result(Entity, MedperfSchema, ApprovableSchema): +class Result(Entity, ApprovableSchema): """ Class representing a Result entry @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): """Creates a new result instance""" super().__init__(*args, **kwargs) - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" + self.local_id = f"b{self.benchmark}m{self.model}d{self.dataset}" @classmethod def _Entity__remote_prefilter(cls, filters: dict) -> callable: diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index cac3d3a01..79926abd9 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -8,7 +8,15 @@ from medperf.utils import format_errors_dict -class MedperfBaseSchema(BaseModel): +class MedperfSchema(BaseModel): + for_test: bool = False + id: Optional[int] + name: str = Field(..., max_length=64) + owner: Optional[int] + is_valid: bool = True + created_at: Optional[datetime] + modified_at: Optional[datetime] + def __init__(self, *args, **kwargs): """Override the ValidationError procedure so we can format the error message in our desired way @@ -68,27 +76,17 @@ def empty_str_to_none(cls, v): return None return v - class Config: - allow_population_by_field_name = True - extra = "allow" - use_enum_values = True - - -class MedperfSchema(BaseModel): - for_test: bool = False - id: Optional[int] - name: str = Field(..., max_length=64) - owner: Optional[int] - is_valid: bool = True - created_at: Optional[datetime] - modified_at: 
Optional[datetime] - @validator("name", pre=True, always=True) def name_max_length(cls, v, *, values, **kwargs): if not values["for_test"] and len(v) > 20: raise ValueError("The name must have no more than 20 characters") return v + class Config: + allow_population_by_field_name = True + extra = "allow" + use_enum_values = True + class DeployableSchema(BaseModel): state: str = "DEVELOPMENT" From cc0dac0267b9106b0cb7e773bdc0176022b50ece Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 04:54:57 +0200 Subject: [PATCH 06/18] use local_id in place of generated uid for clarity --- cli/medperf/cli.py | 2 +- cli/medperf/commands/benchmark/submit.py | 2 +- cli/medperf/commands/compatibility_test/run.py | 2 +- cli/medperf/commands/compatibility_test/utils.py | 1 + cli/medperf/commands/dataset/set_operational.py | 1 + cli/medperf/commands/execution.py | 14 +++++++------- cli/medperf/commands/result/create.py | 2 +- 7 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 4fc7102c4..0910c3ed8 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -71,7 +71,7 @@ def execute( please run the command again with the --no-cache option.\n""" ) else: - ResultSubmission.run(result.generated_uid, approved=approval) + ResultSubmission.run(result.local_id, approved=approval) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/benchmark/submit.py b/cli/medperf/commands/benchmark/submit.py index ebace1880..05d1a0d10 100644 --- a/cli/medperf/commands/benchmark/submit.py +++ b/cli/medperf/commands/benchmark/submit.py @@ -79,7 +79,7 @@ def run_compatibility_test(self): self.ui.print("Running compatibility test") self.bmk.write() data_uid, results = CompatibilityTestExecution.run( - benchmark=self.bmk.generated_uid, + benchmark=self.bmk.local_id, no_cache=self.no_cache, skip_data_preparation_step=self.skip_data_preparation_step, ) diff --git a/cli/medperf/commands/compatibility_test/run.py 
b/cli/medperf/commands/compatibility_test/run.py index 2e3082849..f06603d57 100644 --- a/cli/medperf/commands/compatibility_test/run.py +++ b/cli/medperf/commands/compatibility_test/run.py @@ -239,7 +239,7 @@ def cached_results(self): """ if self.no_cache: return - uid = self.report.generated_uid + uid = self.report.local_id try: report = TestReport.get(uid) except InvalidArgumentError: diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index c56a57d41..3e1e4e26f 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -155,6 +155,7 @@ def create_test_dataset( remove_path(new_path) os.rename(old_path, new_path) dataset.generated_uid = new_generated_uid + dataset.local_id = new_generated_uid dataset.write() return new_generated_uid diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 37758ddfe..985d0ce28 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -40,6 +40,7 @@ def generate_uids(self): generated_uid = get_folders_hash([prepared_data_path, prepared_labels_path]) self.dataset.input_data_hash = in_uid self.dataset.generated_uid = generated_uid + self.dataset.local_id = generated_uid # Not relevant, but for consistency def set_statistics(self): with open(self.dataset.statistics_path, "r") as f: diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index d8afb2244..85416fe96 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -47,12 +47,12 @@ def prepare(self): logging.debug(f"tmp results output: {self.results_path}") def __setup_logs_path(self): - model_uid = self.model.generated_uid - eval_uid = self.evaluator.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + eval_uid = self.evaluator.local_id + data_uid 
= self.dataset.local_id logs_path = os.path.join( - config.experiments_logs_folder, str(model_uid), str(data_hash) + config.experiments_logs_folder, str(model_uid), str(data_uid) ) os.makedirs(logs_path, exist_ok=True) model_logs_path = os.path.join(logs_path, "model.log") @@ -60,10 +60,10 @@ def __setup_logs_path(self): return model_logs_path, metrics_logs_path def __setup_predictions_path(self): - model_uid = self.model.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + data_uid = self.dataset.local_id preds_path = os.path.join( - config.predictions_folder, str(model_uid), str(data_hash) + config.predictions_folder, str(model_uid), str(data_uid) ) if os.path.exists(preds_path): msg = f"Found existing predictions for model {self.model.id} on dataset " diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 760dddc94..26d52fa2e 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -257,7 +257,7 @@ def print_summary(self): data_lists_for_display.append( [ experiment["model_uid"], - experiment["result"].generated_uid, + experiment["result"].local_id, experiment["result"].metadata["partial"], experiment["cached"], experiment["error"], From d808b782a66dc51ae9fd00dca3d184b2a58bee62 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 04:55:04 +0200 Subject: [PATCH 07/18] update tests --- .../tests/commands/benchmark/test_submit.py | 4 ++-- .../tests/commands/mlcube/test_submit.py | 2 +- .../tests/commands/result/test_submit.py | 1 + cli/medperf/tests/commands/test_execution.py | 18 +++++++++--------- cli/medperf/tests/entities/utils.py | 10 +++++----- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/cli/medperf/tests/commands/benchmark/test_submit.py b/cli/medperf/tests/commands/benchmark/test_submit.py index b00e1c5a8..7e2d5b23b 100644 --- a/cli/medperf/tests/commands/benchmark/test_submit.py +++ 
b/cli/medperf/tests/commands/benchmark/test_submit.py @@ -94,7 +94,7 @@ def test_run_compatibility_test_uses_expected_default_parameters(mocker, comms, # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=True, skip_data_preparation_step=False + benchmark=bmk.local_id, no_cache=True, skip_data_preparation_step=False ) @@ -117,7 +117,7 @@ def test_run_compatibility_test_with_passed_parameters(mocker, force, skip, comm # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=force, skip_data_preparation_step=skip + benchmark=bmk.local_id, no_cache=force, skip_data_preparation_step=skip ) diff --git a/cli/medperf/tests/commands/mlcube/test_submit.py b/cli/medperf/tests/commands/mlcube/test_submit.py index 630390205..a946c1fef 100644 --- a/cli/medperf/tests/commands/mlcube/test_submit.py +++ b/cli/medperf/tests/commands/mlcube/test_submit.py @@ -57,7 +57,7 @@ def test_to_permanent_path_renames_correctly(mocker, comms, ui, cube, uid): submission.cube = cube spy = mocker.patch("os.rename") mocker.patch("os.path.exists", return_value=False) - old_path = os.path.join(config.cubes_folder, cube.generated_uid) + old_path = os.path.join(config.cubes_folder, cube.local_id) new_path = os.path.join(config.cubes_folder, str(uid)) # Act submission.to_permanent_path({**cube.todict(), "id": uid}) diff --git a/cli/medperf/tests/commands/result/test_submit.py b/cli/medperf/tests/commands/result/test_submit.py index 10680fbe1..26b03fbcc 100644 --- a/cli/medperf/tests/commands/result/test_submit.py +++ b/cli/medperf/tests/commands/result/test_submit.py @@ -25,6 +25,7 @@ def submission(mocker, comms, ui, result, dataset): sub = ResultSubmission(1) mocker.patch(PATCH_SUBMISSION.format("Result"), return_value=result) mocker.patch(PATCH_SUBMISSION.format("Result.get"), return_value=result) + sub.get_result() return sub diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index 
669d7dfd9..d50ca5d31 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -102,8 +102,8 @@ def test_failure_with_existing_predictions(mocker, setup, ignore_model_errors, f # Arrange preds_path = os.path.join( config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) fs.create_dir(preds_path) @@ -149,22 +149,22 @@ def test_cube_run_are_called_properly(mocker, setup): # Arrange exp_preds_path = os.path.join( config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) exp_model_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, "model.log", ) exp_metrics_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, - f"metrics_{INPUT_EVALUATOR.generated_uid}.log", + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, + f"metrics_{INPUT_EVALUATOR.local_id}.log", ) exp_model_call = call( diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 19c3178e3..264873f45 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -16,14 +16,14 @@ # Setup Benchmark def setup_benchmark_fs(ents, fs): for ent in ents: - # Assume we're passing ids, names, or dicts + # Assume we're passing ids, local_ids, or dicts if isinstance(ent, dict): bmk_contents = TestBenchmark(**ent) elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): bmk_contents = TestBenchmark(id=str(ent)) else: bmk_contents = TestBenchmark(id=None, name=ent) - bmk_contents.generated_uid = ent + bmk_contents.local_id = ent bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] @@ -62,7 +62,7 @@ def setup_cube_fs(ents, 
fs): cube = TestCube(id=str(ent)) else: cube = TestCube(id=None, name=ent) - cube.generated_uid = ent + cube.local_id = ent meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) meta = cube.todict() @@ -136,7 +136,7 @@ def setup_dset_fs(ents, fs): dset_contents = TestDataset(id=str(ent)) else: dset_contents = TestDataset(id=None, name=ent) - dset_contents.generated_uid = ent + dset_contents.local_id = ent reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube @@ -170,7 +170,7 @@ def setup_result_fs(ents, fs): result_contents = TestResult(id=str(ent)) else: result_contents = TestResult(id=None, name=ent) - result_contents.generated_uid = ent + result_contents.local_id = ent result_file = os.path.join(result_contents.path, config.results_info_file) bmk_id = result_contents.benchmark From d42e4d808dd12932b0c70bfc29cb5c9b6ab86d2e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 05:46:13 +0200 Subject: [PATCH 08/18] no need to complicate things --- cli/medperf/entities/benchmark.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index e35849299..e9f3117d0 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -58,7 +58,9 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.local_id = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" + @property + def local_id(self): + return self.name @classmethod def _Entity__remote_prefilter(cls, filters: dict) -> callable: From cd72637ce1415a58f688a0884df553df17825219 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 05:46:24 +0200 Subject: [PATCH 09/18] use dynamic local_id --- cli/medperf/commands/compatibility_test/utils.py | 1 - cli/medperf/commands/dataset/set_operational.py | 1 - cli/medperf/entities/cube.py | 5 ++++- 
cli/medperf/entities/dataset.py | 5 ++++- cli/medperf/entities/interface.py | 4 ++++ cli/medperf/entities/report.py | 4 ++-- cli/medperf/entities/result.py | 4 +++- cli/medperf/tests/entities/test_entity.py | 4 ++-- cli/medperf/tests/entities/utils.py | 8 ++------ 9 files changed, 21 insertions(+), 15 deletions(-) diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index 3e1e4e26f..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -155,7 +155,6 @@ def create_test_dataset( remove_path(new_path) os.rename(old_path, new_path) dataset.generated_uid = new_generated_uid - dataset.local_id = new_generated_uid dataset.write() return new_generated_uid diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 985d0ce28..37758ddfe 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -40,7 +40,6 @@ def generate_uids(self): generated_uid = get_folders_hash([prepared_data_path, prepared_labels_path]) self.dataset.input_data_hash = in_uid self.dataset.generated_uid = generated_uid - self.dataset.local_id = generated_uid # Not relevant, but for consistency def set_statistics(self): with open(self.dataset.statistics_path, "r") as f: diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 61b71867f..bc0b415ce 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -70,12 +70,15 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.local_id = self.name self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: self.params_path = os.path.join(self.path, config.params_filename) + @property + def local_id(self): + return self.name + @classmethod def _Entity__remote_prefilter(cls, filters: dict): 
"""Applies filtering logic that must be done before retrieving remote entities diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index b65989182..f6999fa45 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -62,13 +62,16 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.local_id = self.generated_uid self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) self.metadata_path = os.path.join(self.path, config.metadata_folder) self.statistics_path = os.path.join(self.path, config.statistics_filename) + @property + def local_id(self): + return self.generated_uid + def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) data = {"data_path": raw_data_path, "labels_path": raw_labels_path} diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 8709a9589..356d65770 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -28,6 +28,10 @@ def get_metadata_filename(): def get_comms_uploader() -> Callable[[dict], dict]: raise NotImplementedError() + @property + def local_id(self) -> str: + raise NotImplementedError() + @property def identifier(self) -> Union[int, str]: return self.id or self.local_id diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 7cf220e8e..a2488e11b 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -54,9 +54,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.id = None self.for_test = True - self.local_id = self.__generate_uid() - def __generate_uid(self): + @property + def local_id(self): """A helper that generates a unique hash for a test report.""" 
params = self.todict() del params["results"] diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index f5cc5243b..63e13ecb2 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -46,7 +46,9 @@ def __init__(self, *args, **kwargs): """Creates a new result instance""" super().__init__(*args, **kwargs) - self.local_id = f"b{self.benchmark}m{self.model}d{self.dataset}" + @property + def local_id(self): + return self.name @classmethod def _Entity__remote_prefilter(cls, filters: dict) -> callable: diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index b9d309f39..5f2d24b3a 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -93,8 +93,8 @@ def test_all_unregistered_returns_all_unregistered(self, Implementation): entities = Implementation.all(unregistered=True) # Assert - retrieved_names = set([e.name for e in entities]) - assert self.unregistered_ids == retrieved_names + retrieved_ids = set([e.local_id for e in entities]) + assert self.unregistered_ids == retrieved_ids @pytest.mark.parametrize( diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 264873f45..c3bde6feb 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -23,7 +23,6 @@ def setup_benchmark_fs(ents, fs): bmk_contents = TestBenchmark(id=str(ent)) else: bmk_contents = TestBenchmark(id=None, name=ent) - bmk_contents.local_id = ent bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] @@ -62,7 +61,6 @@ def setup_cube_fs(ents, fs): cube = TestCube(id=str(ent)) else: cube = TestCube(id=None, name=ent) - cube.local_id = ent meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) meta = cube.todict() @@ -129,14 +127,13 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): for ent in ents: - # 
Assume we're passing ids, names, or dicts + # Assume we're passing ids, generated_uids, or dicts if isinstance(ent, dict): dset_contents = TestDataset(**ent) elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): dset_contents = TestDataset(id=str(ent)) else: - dset_contents = TestDataset(id=None, name=ent) - dset_contents.local_id = ent + dset_contents = TestDataset(id=None, generated_uid=ent) reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube @@ -170,7 +167,6 @@ def setup_result_fs(ents, fs): result_contents = TestResult(id=str(ent)) else: result_contents = TestResult(id=None, name=ent) - result_contents.local_id = ent result_file = os.path.join(result_contents.path, config.results_info_file) bmk_id = result_contents.benchmark From 1c6b01ebfc4624eeb9827e1330b2c177df4251db Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 06:56:09 +0200 Subject: [PATCH 10/18] test some type annotations --- cli/medperf/commands/dataset/set_operational.py | 2 +- cli/medperf/commands/list.py | 15 ++++++++++++--- cli/medperf/commands/result/create.py | 2 +- cli/medperf/commands/view.py | 13 ++++++++++--- cli/medperf/entities/interface.py | 15 +++++++-------- cli/medperf/entities/report.py | 2 +- 6 files changed, 32 insertions(+), 17 deletions(-) diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 37758ddfe..6aae2ab46 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -21,7 +21,7 @@ def run(cls, dataset_id: int, approved: bool = False): def __init__(self, dataset_id: int, approved: bool): self.ui = config.ui - self.dataset = Dataset.get(dataset_id) + self.dataset: Dataset = Dataset.get(dataset_id) self.approved = approved def validate(self): diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index ddf6aae6e..99236ac3f 100644 --- 
a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -1,3 +1,5 @@ +from typing import List, Type +from medperf.entities.interface import Entity from medperf.exceptions import InvalidArgumentError from tabulate import tabulate @@ -8,8 +10,8 @@ class EntityList: @staticmethod def run( - entity_class, - fields, + entity_class: Type[Entity], + fields: List[str], unregistered: bool = False, mine_only: bool = False, **kwargs, @@ -29,7 +31,14 @@ def run( entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs): + def __init__( + self, + entity_class: Type[Entity], + fields: List[str], + unregistered: bool, + mine_only: bool, + **kwargs, + ): self.entity_class = entity_class self.fields = fields self.unregistered = unregistered diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 26d52fa2e..d156fb659 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -145,7 +145,7 @@ def __validate_models(self, benchmark_models): def load_cached_results(self): user_id = get_medperf_user_data()["id"] - results = Result.all(filters={"owner": user_id}) + results: List[Result] = Result.all(filters={"owner": user_id}) results += Result.all(unregistered=True) benchmark_dset_results = [ result diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index 8c2a4179f..d19aedec0 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -1,6 +1,6 @@ import yaml import json -from typing import Union +from typing import Union, Type from medperf import config from medperf.account_management import get_medperf_user_data @@ -12,7 +12,7 @@ class EntityView: @staticmethod def run( entity_id: Union[int, str], - entity_class: Entity, + entity_class: Type[Entity], format: str = "yaml", unregistered: bool = False, mine_only: bool = False, @@ -41,7 +41,14 @@ def run( entity_view.store() def __init__( - 
self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs + self, + entity_id: Union[int, str], + entity_class: Type[Entity], + format: str, + unregistered: bool, + mine_only: bool, + output: str, + **kwargs, ): self.entity_id = entity_id self.entity_class = entity_class diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 356d65770..ff28c7a4b 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -9,19 +9,19 @@ class Entity(MedperfSchema, ABC): @staticmethod - def get_type(): + def get_type() -> str: raise NotImplementedError() @staticmethod - def get_storage_path(): + def get_storage_path() -> str: raise NotImplementedError() @staticmethod - def get_comms_retriever(): + def get_comms_retriever() -> Callable[[int], dict]: raise NotImplementedError() @staticmethod - def get_metadata_filename(): + def get_metadata_filename() -> str: raise NotImplementedError() @staticmethod @@ -37,11 +37,11 @@ def identifier(self) -> Union[int, str]: return self.id or self.local_id @property - def is_registered(self): + def is_registered(self) -> bool: return self.id is not None @property - def path(self): + def path(self) -> str: storage_path = self.get_storage_path() return os.path.join(storage_path, str(self.identifier)) @@ -88,8 +88,7 @@ def __unregistered_all(cls) -> List["Entity"]: for uid in uids: if uid.isdigit(): continue - meta = cls.__get_local_dict(uid) - entity = cls(**meta) + entity = cls.__local_get(uid) entities.append(entity) return entities diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index a2488e11b..cefd168b3 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -67,7 +67,7 @@ def set_results(self, results): self.results = results @classmethod - def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["TestReport"]: 
assert unregistered, "Reports are only unregistered" assert filters == {}, "Reports cannot be filtered" return super().all(unregistered=True, filters={}) From 271ee8dcf4363e25d9947ae822b7f01eed4db7dc Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 16:01:01 +0200 Subject: [PATCH 11/18] modify __remote_prefilter --- cli/medperf/entities/benchmark.py | 4 ++-- cli/medperf/entities/cube.py | 4 ++-- cli/medperf/entities/dataset.py | 4 ++-- cli/medperf/entities/interface.py | 6 +++--- cli/medperf/entities/result.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index e9f3117d0..e03fcdb4f 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -62,8 +62,8 @@ def __init__(self, *args, **kwargs): def local_id(self): return self.name - @classmethod - def _Entity__remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index bc0b415ce..714342c53 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -79,8 +79,8 @@ def __init__(self, *args, **kwargs): def local_id(self): return self.name - @classmethod - def _Entity__remote_prefilter(cls, filters: dict): + @staticmethod + def remote_prefilter(filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index f6999fa45..7f13c2185 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -97,8 +97,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - @classmethod - def _Entity__remote_prefilter(cls, filters: dict) -> callable: + 
@staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index ff28c7a4b..1ac6d64fb 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -69,7 +69,7 @@ def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: @classmethod def __remote_all(cls, filters: dict) -> List["Entity"]: - comms_fn = cls.__remote_prefilter(filters) + comms_fn = cls.remote_prefilter(filters) entity_meta = comms_fn() entities = [cls(**meta) for meta in entity_meta] return entities @@ -93,8 +93,8 @@ def __unregistered_all(cls) -> List["Entity"]: return entities - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index 63e13ecb2..0e96d1feb 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -50,8 +50,8 @@ def __init__(self, *args, **kwargs): def local_id(self): return self.name - @classmethod - def _Entity__remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: From 88a99f174a161ee6641cd4c2eca4385a30c2f583 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 15 May 2024 16:06:33 +0200 Subject: [PATCH 12/18] rename outdated intermediate vars --- cli/medperf/entities/interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 1ac6d64fb..9824210ef 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -168,12 +168,12 
@@ def __get_local_dict(cls, uid: Union[str, int]) -> dict: logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") storage_path = cls.get_storage_path() metadata_filename = cls.get_metadata_filename() - bmk_file = os.path.join(storage_path, str(uid), metadata_filename) - if not os.path.exists(bmk_file): + entity_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(entity_file): raise InvalidArgumentError( f"No {cls.get_type()} with the given uid could be found" ) - with open(bmk_file, "r") as f: + with open(entity_file, "r") as f: data = yaml.safe_load(f) return data From d4dd21253ad5218ed86ab0a91c75a933b947c67b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 16 May 2024 00:12:24 +0200 Subject: [PATCH 13/18] use TypeVar for type hints --- .../commands/dataset/set_operational.py | 2 +- cli/medperf/commands/result/create.py | 2 +- cli/medperf/entities/interface.py | 21 ++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/cli/medperf/commands/dataset/set_operational.py b/cli/medperf/commands/dataset/set_operational.py index 6aae2ab46..37758ddfe 100644 --- a/cli/medperf/commands/dataset/set_operational.py +++ b/cli/medperf/commands/dataset/set_operational.py @@ -21,7 +21,7 @@ def run(cls, dataset_id: int, approved: bool = False): def __init__(self, dataset_id: int, approved: bool): self.ui = config.ui - self.dataset: Dataset = Dataset.get(dataset_id) + self.dataset = Dataset.get(dataset_id) self.approved = approved def validate(self): diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index d156fb659..26d52fa2e 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -145,7 +145,7 @@ def __validate_models(self, benchmark_models): def load_cached_results(self): user_id = get_medperf_user_data()["id"] - results: List[Result] = Result.all(filters={"owner": user_id}) + results = Result.all(filters={"owner": 
user_id}) results += Result.all(unregistered=True) benchmark_dset_results = [ result diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 9824210ef..835fbdf22 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -5,6 +5,9 @@ import yaml from medperf.exceptions import MedperfException, InvalidArgumentError from medperf.entities.schemas import MedperfSchema +from typing import Type, TypeVar + +EntityType = TypeVar("EntityType", bound="Entity") class Entity(MedperfSchema, ABC): @@ -46,7 +49,9 @@ def path(self) -> str: return os.path.join(storage_path, str(self.identifier)) @classmethod - def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: + def all( + cls: Type[EntityType], unregistered: bool = False, filters: dict = {} + ) -> List[EntityType]: """Gets a list of all instances of the respective entity. Whether the list is local or remote depends on the implementation. @@ -68,14 +73,14 @@ def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: return cls.__remote_all(filters=filters) @classmethod - def __remote_all(cls, filters: dict) -> List["Entity"]: + def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]: comms_fn = cls.remote_prefilter(filters) entity_meta = comms_fn() entities = [cls(**meta) for meta in entity_meta] return entities @classmethod - def __unregistered_all(cls) -> List["Entity"]: + def __unregistered_all(cls: Type[EntityType]) -> List[EntityType]: entities = [] storage_path = cls.get_storage_path() try: @@ -106,7 +111,9 @@ def remote_prefilter(filters: dict) -> callable: raise NotImplementedError @classmethod - def get(cls, uid: Union[str, int], local_only: bool = False) -> "Entity": + def get( + cls: Type[EntityType], uid: Union[str, int], local_only: bool = False + ) -> EntityType: """Gets an instance of the respective entity. 
Wether this requires only local read or remote calls depends on the implementation. @@ -124,7 +131,7 @@ def get(cls, uid: Union[str, int], local_only: bool = False) -> "Entity": return cls.__remote_get(uid) @classmethod - def __remote_get(cls, uid: int) -> "Entity": + def __remote_get(cls: Type[EntityType], uid: int) -> EntityType: """Retrieves and creates an entity instance from the comms instance. Args: @@ -141,7 +148,7 @@ def __remote_get(cls, uid: int) -> "Entity": return entity @classmethod - def __local_get(cls, uid: Union[str, int]) -> "Entity": + def __local_get(cls: Type[EntityType], uid: Union[str, int]) -> EntityType: """Retrieves and creates an entity instance from the local storage. Args: @@ -156,7 +163,7 @@ def __local_get(cls, uid: Union[str, int]) -> "Entity": return entity @classmethod - def __get_local_dict(cls, uid: Union[str, int]) -> dict: + def __get_local_dict(cls: Type[EntityType], uid: Union[str, int]) -> dict: """Retrieves a local entity information Args: From 075782a30afb926a56757e0ae795b3f5e1e8af6d Mon Sep 17 00:00:00 2001 From: Viacheslav Kukushkin Date: Tue, 4 Jun 2024 15:02:32 +0300 Subject: [PATCH 14/18] Typo fix --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2f07c511a..550d281ca 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Inside this repo you can find all important pieces for running MedPerf. In its c If you use MedPerf, please cite our main paper: Karargyris, A., Umeton, R., Sheller, M.J. et al. Federated benchmarking of medical artificial intelligence with MedPerf. *Nature Machine Intelligence* **5**, 799–810 (2023). [https://www.nature.com/articles/s42256-023-00652-2](https://www.nature.com/articles/s42256-023-00652-2) -Additonally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). 
+Additionally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). ## Experiments From 8f3ec48771bfc4abf14f34375f075299c6e02fc2 Mon Sep 17 00:00:00 2001 From: Viacheslav Kukushkin Date: Fri, 14 Jun 2024 19:51:56 +0300 Subject: [PATCH 15/18] added cube edit command --- cli/medperf/commands/mlcube/edit.py | 106 ++++++++++++++++++++++++++ cli/medperf/commands/mlcube/mlcube.py | 67 ++++++++++++++++ cli/medperf/comms/interface.py | 12 +++ cli/medperf/comms/rest.py | 19 +++++ cli/medperf/entities/cube.py | 20 +++++ cli/medperf/entities/edit_cube.py | 20 +++++ cli/medperf/entities/interface.py | 6 ++ 7 files changed, 250 insertions(+) create mode 100644 cli/medperf/commands/mlcube/edit.py create mode 100644 cli/medperf/entities/edit_cube.py diff --git a/cli/medperf/commands/mlcube/edit.py b/cli/medperf/commands/mlcube/edit.py new file mode 100644 index 000000000..ed4307089 --- /dev/null +++ b/cli/medperf/commands/mlcube/edit.py @@ -0,0 +1,106 @@ +import logging +from typing import Union + +import medperf.config as config +from medperf.entities.cube import Cube +from medperf.entities.edit_cube import EditCubeData + + +class EditCube: + @classmethod + def run(cls, cube_uid: Union[str, int], mlcube_partial_info: EditCubeData): + """Update mlcube in the development mode on the medperf server + + Args: + cube_uid: uid of cube to modify + mlcube_partial_info (dict): Dictionary containing the modified fields. 
+ """ + ui = config.ui + + logging.debug("Downloading initial MLCube..") + edition = cls(cube_uid, mlcube_partial_info) + logging.debug("Validating MLCube DEVELOPMENT state..") + edition.validate_dev_state() + + with ui.interactive(): + ui.text = "Validating updated MLCube can be downloaded" + logging.debug("Applying MLCube edit..") + edition.apply() + ui.text = "Submitting MLCube edit to MedPerf" + logging.debug("Uploading MLCube..") + edition.upload() + edition.write() + + def __init__(self, cube_uid: Union[str, int], edit_info: EditCubeData): + self.ui = config.ui + self.cube = Cube.get(cube_uid) + self.edit_info = edit_info + + def validate_dev_state(self): + if self.cube.state != "DEVELOPMENT": + raise ValueError("Only cubes in development state can be edited") + + def apply(self): + cube = self.cube + new = self.edit_info + + if new.name: + cube.name = new.name + + if new.git_mlcube_url: + cube.git_mlcube_url = new.git_mlcube_url + + if new.git_mlcube_hash: + cube.git_mlcube_hash = new.git_mlcube_hash + elif new.git_mlcube_url is not None: + cube.git_mlcube_hash = "" + + if new.git_parameters_url: + cube.git_parameters_url = new.git_parameters_url + + if new.parameters_hash: + cube.parameters_hash = new.parameters_hash + elif new.git_parameters_url is not None: + cube.parameters_hash = "" + + if new.image_tarball_url: + cube.image_tarball_url = new.image_tarball_url + + if new.image_tarball_hash: + cube.image_tarball_hash = new.image_tarball_hash + elif new.image_tarball_url is not None: + cube.image_tarball_hash = "" + + if new.additional_files_tarball_url: + cube.additional_files_tarball_url = new.additional_files_tarball_url + + if new.additional_files_tarball_hash: + cube.additional_files_tarball_hash = new.additional_files_tarball_hash + elif new.additional_files_tarball_url is not None: + cube.additional_files_tarball_hash = "" + + self.download() + + if new.git_mlcube_hash == "": + new.git_mlcube_hash = cube.git_mlcube_hash + if new.parameters_hash == 
"": + new.parameters_hash = cube.parameters_hash + if new.image_tarball_hash == "": + new.image_tarball_hash = cube.image_tarball_hash + if new.additional_files_tarball_hash == "": + new.additional_files_tarball_hash = cube.additional_files_tarball_hash + + def download(self): + logging.debug("removing from filesystem...") + self.cube.remove_from_filesystem() + logging.debug("download config files..") + self.cube.download_config_files() + logging.debug("download run files..") + self.cube.download_run_files() + + def upload(self): + updated_body = Cube.edit(self.cube.id, self.edit_info) + self.cube = Cube(**updated_body) + + def write(self): + self.cube.write() diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 9256f35f2..5d36160bc 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -4,11 +4,13 @@ import medperf.config as config from medperf.decorators import clean_except from medperf.entities.cube import Cube +from medperf.entities.edit_cube import EditCubeData from medperf.commands.list import EntityList from medperf.commands.view import EntityView from medperf.commands.mlcube.create import CreateCube from medperf.commands.mlcube.submit import SubmitCube from medperf.commands.mlcube.associate import AssociateCube +from medperf.commands.mlcube.edit import EditCube app = typer.Typer() @@ -123,6 +125,71 @@ def submit( config.ui.print("✅ Done!") +@app.command("edit") +@clean_except +def edit( + uid: str = typer.Option(..., "--uid", "-u", help="UID of the MLCube to edit"), + name: str = typer.Option(None, "--name", "-n", help="Name of the mlcube"), + mlcube_file: str = typer.Option( + None, + "--mlcube-file", + "-m", + help="Identifier to download the mlcube file. 
See the description above", + ), + mlcube_hash: str = typer.Option(None, "--mlcube-hash", help="hash of mlcube file"), + parameters_file: str = typer.Option( + None, + "--parameters-file", + "-p", + help="Identifier to download the parameters file. See the description above", + ), + parameters_hash: str = typer.Option(None, "--parameters-hash", help="hash of parameters file"), + additional_file: str = typer.Option( + None, + "--additional-file", + "-a", + help="Identifier to download the additional files tarball. See the description above", + ), + additional_hash: str = typer.Option(None, "--additional-hash", help="hash of additional file"), + image_file: str = typer.Option( + None, + "--image-file", + "-i", + help="Identifier to download the image file. See the description above", + ), + image_hash: str = typer.Option(None, "--image-hash", help="hash of image file"), +): + """Updates the existing mlcube. Only mlcubes in DEVELOPMENT state may be updated.\n + The following assets:\n + - mlcube_file\n + - parameters_file\n + - additional_file\n + - image_file\n + are expected to be given in the following format: + where `source_prefix` instructs the client how to download the resource, and `resource_identifier` + is the identifier used to download the asset. The following are supported:\n + 1. A direct link: "direct:"\n + 2. An asset hosted on the Synapse platform: "synapse:"\n\n + + If a URL is given without a source prefix, it will be treated as a direct download link. 
+ """ + + mlcube_partial_info = EditCubeData( + uid=uid, + name=name, + git_mlcube_url=mlcube_file, + git_mlcube_hash=mlcube_hash, + git_parameters_url=parameters_file, + parameters_hash=parameters_hash, + image_tarball_url=image_file, + image_tarball_hash=image_hash, + additional_files_tarball_url=additional_file, + additional_files_tarball_hash=additional_hash, + ) + EditCube.run(uid, mlcube_partial_info) + config.ui.print("✅ Done!") + + @app.command("associate") @clean_except def associate( diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index 01436e435..c8635b3eb 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -87,6 +87,18 @@ def get_cube_metadata(self, cube_uid: int) -> dict: dict: Dictionary containing url and hashes for the cube files """ + @abstractmethod + def edit_cube(self, cube_uid: int, edited_fields: dict) -> dict: + """Updates mlcube with dict of changed fields + + Args: + cube_uid (int): UID of the desired cube. + edited_fields: Dictionary containing the fields to be updated + + Returns: + dict: Dictionary containing the full mlcube + """ + @abstractmethod def get_user_cubes(self) -> List[dict]: """Retrieves metadata from all cubes registered by the user diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 5ac236f93..9ccb362dc 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -227,6 +227,25 @@ def get_cube_metadata(self, cube_uid: int) -> dict: ) return res.json() + def edit_cube(self, cube_uid: int, edited_fields: dict) -> dict: + """Updates mlcube with dict of changed fields + + Args: + cube_uid (int): UID of the desired cube. 
+ edited_fields: Dictionary containing the fields to be updated + + Returns: + dict: Dictionary containing the full mlcube + """ + res = self.__auth_put(f"{self.server_url}/mlcubes/{cube_uid}/", json=edited_fields) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"the specified cube doesn't exist {details}" + ) + return res.json() + def get_user_cubes(self) -> List[dict]: """Retrieves metadata from all cubes registered by the user diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 714342c53..cd837b62c 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -12,6 +12,7 @@ generate_tmp_path, spawn_and_kill, ) +from medperf.entities.edit_cube import EditCubeData from medperf.entities.interface import Entity from medperf.entities.schemas import DeployableSchema from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError @@ -62,6 +63,12 @@ def get_metadata_filename(): def get_comms_uploader(): return config.comms.upload_mlcube + # as currently edit is implemented only for mlcubes, this function is not defined + # in interface and thus is not overridden. 
+ @staticmethod + def get_comms_edit(): + return config.comms.edit_cube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -367,3 +374,16 @@ def display_dict(self): "Created At": self.created_at, "Registered": self.is_registered, } + + @staticmethod + def edit(cube_uid: Union[str, int], edited_fields: EditCubeData) -> Dict: + """Uploads the mlcube diff and updates the entity + + Returns: + Dict: Dictionary with the updated cube + """ + + comms_func = Cube.get_comms_edit() + logging.debug(f"Editing cube {cube_uid} with fields: {edited_fields}") + updated_body = comms_func(cube_uid, edited_fields.not_null_dict()) + return updated_body diff --git a/cli/medperf/entities/edit_cube.py b/cli/medperf/entities/edit_cube.py new file mode 100644 index 000000000..99d85b5a2 --- /dev/null +++ b/cli/medperf/entities/edit_cube.py @@ -0,0 +1,20 @@ +from typing import Union, Optional +from pydantic import BaseModel + + +class EditCubeData(BaseModel): + """represents a partial mlcube with fields to be updated""" + uid: Union[str, int] + name: Optional[str] + git_mlcube_url: Optional[str] + git_mlcube_hash: Optional[str] + git_parameters_url: Optional[str] + parameters_hash: Optional[str] + image_tarball_url: Optional[str] + image_tarball_hash: Optional[str] + additional_files_tarball_url: Optional[str] + additional_files_tarball_hash: Optional[str] + + def not_null_dict(self): + """returns a dictionary of the non-null fields""" + return {k: v for k, v in self.dict().items() if v is not None} diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index 835fbdf22..cff39768d 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,3 +1,4 @@ +import shutil from typing import List, Dict, Union, Callable from abc import ABC import logging @@ -199,6 +200,11 @@ def write(self) -> str: yaml.dump(data, f) return entity_file + def remove_from_filesystem(self): + """Removes the entity folder recursively from the 
local storage""" + # TODO: might be dangerous + shutil.rmtree(self.path, ignore_errors=True) + def upload(self) -> Dict: """Upload the entity-related information to the communication's interface From 91b512b1a62da834c93ce6d1903a2627e5295dd3 Mon Sep 17 00:00:00 2001 From: Viacheslav Kukushkin Date: Sat, 15 Jun 2024 21:35:31 +0300 Subject: [PATCH 16/18] image_hash should be updated also --- cli/medperf/commands/mlcube/edit.py | 26 ++++++++++++++++---------- cli/medperf/entities/edit_cube.py | 3 ++- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/cli/medperf/commands/mlcube/edit.py b/cli/medperf/commands/mlcube/edit.py index ed4307089..31b44df00 100644 --- a/cli/medperf/commands/mlcube/edit.py +++ b/cli/medperf/commands/mlcube/edit.py @@ -25,7 +25,7 @@ def run(cls, cube_uid: Union[str, int], mlcube_partial_info: EditCubeData): with ui.interactive(): ui.text = "Validating updated MLCube can be downloaded" logging.debug("Applying MLCube edit..") - edition.apply() + edition.apply_and_get_hashes() ui.text = "Submitting MLCube edit to MedPerf" logging.debug("Uploading MLCube..") edition.upload() @@ -40,7 +40,7 @@ def validate_dev_state(self): if self.cube.state != "DEVELOPMENT": raise ValueError("Only cubes in development state can be edited") - def apply(self): + def apply_and_get_hashes(self): cube = self.cube new = self.edit_info @@ -49,11 +49,13 @@ def apply(self): if new.git_mlcube_url: cube.git_mlcube_url = new.git_mlcube_url + # Differs from further ifs: if mlcube.yaml url is provided, reset image also + cube.image_hash = "" - if new.git_mlcube_hash: - cube.git_mlcube_hash = new.git_mlcube_hash + if new.mlcube_hash: + cube.mlcube_hash = new.mlcube_hash elif new.git_mlcube_url is not None: - cube.git_mlcube_hash = "" + cube.mlcube_hash = "" if new.git_parameters_url: cube.git_parameters_url = new.git_parameters_url @@ -65,6 +67,8 @@ def apply(self): if new.image_tarball_url: cube.image_tarball_url = new.image_tarball_url + # same as with 
git_mlcube_url
+            cube.image_hash = ""
 
         if new.image_tarball_hash:
             cube.image_tarball_hash = new.image_tarball_hash
@@ -81,14 +85,16 @@ def apply(self):
 
         self.download()
 
-        if new.git_mlcube_hash == "":
-            new.git_mlcube_hash = cube.git_mlcube_hash
-        if new.parameters_hash == "":
+        if new.git_mlcube_url and not new.mlcube_hash:
+            new.mlcube_hash = cube.mlcube_hash
+        if new.git_parameters_url and not new.parameters_hash:
             new.parameters_hash = cube.parameters_hash
-        if new.image_tarball_hash == "":
+        if new.image_tarball_url and not new.image_tarball_hash:
             new.image_tarball_hash = cube.image_tarball_hash
-        if new.additional_files_tarball_hash == "":
+        if new.additional_files_tarball_url and not new.additional_files_tarball_hash:
             new.additional_files_tarball_hash = cube.additional_files_tarball_hash
+        if new.git_mlcube_url or new.image_tarball_url:
+            new.image_hash = cube.image_hash
 
     def download(self):
         logging.debug("removing from filesystem...")
diff --git a/cli/medperf/entities/edit_cube.py b/cli/medperf/entities/edit_cube.py
index 99d85b5a2..f4f68bda7 100644
--- a/cli/medperf/entities/edit_cube.py
+++ b/cli/medperf/entities/edit_cube.py
@@ -7,13 +7,14 @@ class EditCubeData(BaseModel):
     uid: Union[str, int]
     name: Optional[str]
     git_mlcube_url: Optional[str]
-    git_mlcube_hash: Optional[str]
+    mlcube_hash: Optional[str]
     git_parameters_url: Optional[str]
     parameters_hash: Optional[str]
     image_tarball_url: Optional[str]
     image_tarball_hash: Optional[str]
     additional_files_tarball_url: Optional[str]
     additional_files_tarball_hash: Optional[str]
+    image_hash: Optional[str] = None
 
     def not_null_dict(self):
         """returns a dictionary of the non-null fields"""

From 3a4614b5957f20253b1cc7187afb70dba6f985f5 Mon Sep 17 00:00:00 2001
From: Viacheslav Kukushkin
Date: Sun, 16 Jun 2024 15:06:42 +0300
Subject: [PATCH 17/18] These params are (1) never used, (2) broken if no
 other params are given: pytest fixture params are treated as a list, and
 one param is passed to the test function at a time.

In our case it would be just str keys. --- cli/medperf/tests/entities/test_benchmark.py | 10 +--------- cli/medperf/tests/entities/test_cube.py | 9 +-------- cli/medperf/tests/entities/test_entity.py | 9 +-------- 3 files changed, 3 insertions(+), 25 deletions(-) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index c36771e12..1a866f0fc 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -7,15 +7,7 @@ PATCH_BENCHMARK = "medperf.entities.benchmark.{}" -@pytest.fixture( - params={ - "unregistered": ["b1", "b2"], - "local": ["b1", "b2", 1, 2, 3], - "remote": [1, 2, 3, 4, 5, 6], - "user": [4], - "models": [10, 11], - } -) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, fs): local_ids = request.param.get("local", []) remote_ids = request.param.get("remote", []) diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 89e7cc5a9..d8f607660 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,14 +24,7 @@ } -@pytest.fixture( - params={ - "unregistered": ["c1", "c2"], - "local": ["c1", "c2", 1, 2, 3], - "remote": [1, 2, 3, 4, 5, 6], - "user": [4], - } -) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index 5f2d24b3a..d25c8ea0e 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -23,14 +23,7 @@ def Implementation(request): return request.param -@pytest.fixture( - params={ - "unregistered": ["e1", "e2"], - "local": ["e1", "e2", 1, 2, 3], - "remote": [1, 2, 3, 4, 5, 6], - "user": [4], - } -) +@pytest.fixture(autouse=True) def setup(request, mocker, comms, Implementation, fs): local_ids = 
request.param.get("local", []) remote_ids = request.param.get("remote", []) From 6a002dc8a87d175cfbf849df8bc5824fb8533191 Mon Sep 17 00:00:00 2001 From: Viacheslav Kukushkin Date: Sun, 16 Jun 2024 21:47:54 +0300 Subject: [PATCH 18/18] Improving tests: making a test storage New storage allows to use both get/upload/edit in the same test (say, upload and then get & check) --- cli/medperf/entities/cube.py | 2 +- cli/medperf/tests/entities/test_benchmark.py | 11 +- cli/medperf/tests/entities/test_cube.py | 4 +- cli/medperf/tests/entities/test_entity.py | 31 +++-- cli/medperf/tests/entities/utils.py | 31 ++--- cli/medperf/tests/mocks/comms.py | 117 +++++++++---------- cli/medperf/tests/mocks/cube.py | 9 +- cli/medperf/tests/mocks/dataset.py | 5 +- 8 files changed, 109 insertions(+), 101 deletions(-) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index cd837b62c..fa250083e 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -251,7 +251,7 @@ def run( Defaults to {}. timeout (int, optional): timeout for the task in seconds. Defaults to None. read_protected_input (bool, optional): Wether to disable write permissions on input volumes. Defaults to True. 
- kwargs (dict): additional arguments that are passed directly to the mlcube command + kwargs: additional arguments that are passed directly to the mlcube command """ kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 1a866f0fc..6fa6aae47 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -3,7 +3,6 @@ from medperf.entities.benchmark import Benchmark from medperf.tests.entities.utils import setup_benchmark_fs, setup_benchmark_comms - PATCH_BENCHMARK = "medperf.entities.benchmark.{}" @@ -14,12 +13,10 @@ def setup(request, mocker, comms, fs): user_ids = request.param.get("user", []) models = request.param.get("models", []) # Have a list that will contain all uploaded entities of the given type - uploaded = [] setup_benchmark_fs(local_ids, fs) - setup_benchmark_comms(mocker, comms, remote_ids, user_ids, uploaded) + setup_benchmark_comms(mocker, comms, remote_ids, user_ids) mocker.patch.object(comms, "get_benchmark_model_associations", return_value=models) - request.param["uploaded"] = uploaded return request.param @@ -44,10 +41,10 @@ def setup(request, mocker, comms, fs): class TestModels: def test_benchmark_get_models_works_as_expected(self, setup, expected_models): # Arrange - id = setup["remote"][0] + id_ = setup["remote"][0] # Act - assciated_models = Benchmark.get_models_uids(id) + associated_models = Benchmark.get_models_uids(id_) # Assert - assert set(assciated_models) == set(expected_models) + assert set(associated_models) == set(expected_models) diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index d8f607660..cd182489f 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -30,12 +30,10 @@ def setup(request, mocker, comms, fs): remote_ents = request.param.get("remote", 
[]) user_ents = request.param.get("user", []) # Have a list that will contain all uploaded entities of the given type - uploaded = [] setup_cube_fs(local_ents, fs) - setup_cube_comms(mocker, comms, remote_ents, user_ents, uploaded) + request.param["storage"] = setup_cube_comms(mocker, comms, remote_ents, user_ents) setup_cube_comms_downloads(mocker, fs) - request.param["uploaded"] = uploaded # Mock additional third party elements mpexpect = MockPexpect(0) diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index d25c8ea0e..960a8edb2 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -16,6 +16,7 @@ setup_result_comms, ) from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.tests.mocks.comms import TestEntityStorage @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -29,7 +30,6 @@ def setup(request, mocker, comms, Implementation, fs): remote_ids = request.param.get("remote", []) user_ids = request.param.get("user", []) # Have a list that will contain all uploaded entities of the given type - uploaded = [] if Implementation == Benchmark: setup_fs = setup_benchmark_fs @@ -44,10 +44,13 @@ def setup(request, mocker, comms, Implementation, fs): elif Implementation == Result: setup_fs = setup_result_fs setup_comms = setup_result_comms + else: + raise NotImplementedError("Wrong implementation") - setup_comms(mocker, comms, remote_ids, user_ids, uploaded) + storage = setup_comms(mocker, comms, remote_ids, user_ids) setup_fs(local_ids, fs) - request.param["uploaded"] = uploaded + + request.param["storage"] = storage return request.param @@ -167,14 +170,16 @@ def set_common_attributes(self, setup): def test_upload_adds_to_remote(self, Implementation, setup): # Arrange - uploaded_entities = setup["uploaded"] + storage: TestEntityStorage = setup["storage"] + assert self.id not in storage.storage + ent = 
Implementation.get(self.id) # Act ent.upload() # Assert - assert ent.todict() in uploaded_entities + assert ent.todict() in storage.uploaded def test_upload_returns_dict(self, Implementation): # Arrange @@ -184,20 +189,30 @@ def test_upload_returns_dict(self, Implementation): ent_dict = ent.upload() # Assert - assert ent_dict == ent.todict() + real_dict = ent.todict() + diff = {} + for k in set(real_dict) | set(ent_dict): + if real_dict.get(k) != ent_dict.get(k): + diff[k] = (real_dict.get(k), ent_dict.get(k)) + assert ent_dict == ent.todict(), f"Expected: {ent_dict}\nGot: {real_dict}\nDiff: {diff}" def test_upload_fails_for_test_entity(self, Implementation, setup): # Arrange - uploaded_entities = setup["uploaded"] + storage: TestEntityStorage = setup["storage"] ent = Implementation.get(self.id) ent.for_test = True + # pre-check + len_before_test = len(storage.uploaded) + assert self.id not in storage.storage # Act with pytest.raises(InvalidArgumentError): ent.upload() # Assert - assert ent.todict() not in uploaded_entities + assert self.id not in storage.storage + assert ent.todict() not in storage.uploaded + assert len(storage.uploaded) == len_before_test @pytest.mark.parametrize( diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index c3bde6feb..98c3a5070 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -1,4 +1,5 @@ import os + from medperf import config import yaml @@ -8,7 +9,7 @@ from medperf.tests.mocks.dataset import TestDataset from medperf.tests.mocks.result import TestResult from medperf.tests.mocks.cube import TestCube -from medperf.tests.mocks.comms import mock_comms_entity_gets +from medperf.tests.mocks.comms import mock_comms_entity_gets, TestEntityStorage PATCH_RESOURCES = "medperf.comms.entity_resources.resources.{}" @@ -37,7 +38,7 @@ def setup_benchmark_fs(ents, fs): pass -def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): +def 
setup_benchmark_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestBenchmark comms_calls = { "get_all": "get_benchmarks", @@ -46,8 +47,8 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): "upload_instance": "upload_benchmark", } mocker.patch.object(comms, "get_benchmark_model_associations", return_value=[]) - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) @@ -70,7 +71,7 @@ def setup_cube_fs(ents, fs): pass -def setup_cube_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_cube_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestCube comms_calls = { "get_all": "get_cubes", @@ -78,15 +79,15 @@ def setup_cube_comms(mocker, comms, all_ents, user_ents, uploaded): "get_instance": "get_cube_metadata", "upload_instance": "upload_mlcube", } - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) def generate_cubefile_fn(fs, path, filename): # all_ids = [ent["id"] if type(ent) == dict else ent for ent in all_ents] - def cubefile_fn(url, cube_path, *args): + def cubefile_fn(url: str, cube_path: str, *args): if url == "broken_url": raise CommunicationRetrievalError filepath = os.path.join(cube_path, path, filename) @@ -144,7 +145,7 @@ def setup_dset_fs(ents, fs): pass -def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_dset_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestDataset comms_calls = { "get_all": "get_datasets", @@ -152,8 +153,8 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): "get_instance": "get_dataset", "upload_instance": "upload_dataset", } - mock_comms_entity_gets( - mocker, comms, generate_fn, 
comms_calls, all_ents, user_ents, uploaded + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) @@ -182,7 +183,7 @@ def setup_result_fs(ents, fs): pass -def setup_result_comms(mocker, comms, all_ents, user_ents, uploaded): +def setup_result_comms(mocker, comms, all_ents, user_ents) -> TestEntityStorage: generate_fn = TestResult comms_calls = { "get_all": "get_results", @@ -192,7 +193,7 @@ def setup_result_comms(mocker, comms, all_ents, user_ents, uploaded): } # Enable dset retrieval since its required for result creation - setup_dset_comms(mocker, comms, [1], [1], uploaded) - mock_comms_entity_gets( - mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded + setup_dset_comms(mocker, comms, [1], [1]) + return mock_comms_entity_gets( + mocker, comms, generate_fn, comms_calls, all_ents, user_ents ) diff --git a/cli/medperf/tests/mocks/comms.py b/cli/medperf/tests/mocks/comms.py index 57084c849..73c1619bd 100644 --- a/cli/medperf/tests/mocks/comms.py +++ b/cli/medperf/tests/mocks/comms.py @@ -1,11 +1,44 @@ # Utility functions for mocking comms and its methods -from typing import Dict, List, Callable, Union +from typing import Dict, List, Callable, Union, TypeVar, Tuple from unittest.mock import MagicMock from pytest_mock.plugin import MockFixture - from medperf.exceptions import CommunicationRetrievalError +class TestEntityStorage: + AnyEntity = TypeVar("AnyEntity") + + def __init__(self, + generate_fun: Callable[[Dict], AnyEntity], + ents: Dict[str, Dict]): + + self.storage = ents + self.uploaded = [] + self.generate_fun = generate_fun # 🥳 <- generated fun + + def get(self, id_) -> Dict: + if id_ not in self.storage: + raise CommunicationRetrievalError(f"Get entity {id_}: not found in test storage") + return self.storage[id_] + + def upload(self, ent: Dict) -> Dict: + id_ = ent["id"] + if id_ is None or id_ == "": # not include 0 as 0 is a valid id + id_ = str(-len(self.storage)) # some non-existent id + 
assert id_ not in self.storage, f"Upload failed: generated id {id_} already exists in storage" + self.storage[id_] = self.generate_fun(**ent).todict() + self.uploaded.append(ent) + return self.storage[id_] + + def edit(self, ent: Dict): + id_ = ent["id"] + if id_ not in self.storage: + raise CommunicationRetrievalError(f"Edit entity {id_}: not found in test storage") + orig_value = self.storage[id_] + new_value = {**orig_value, **ent} # rewrites all fields from ent + self.storage[id_] = new_value + + def mock_comms_entity_gets( mocker: MockFixture, comms: MagicMock, @@ -13,8 +46,7 @@ def mock_comms_entity_gets( comms_calls: Dict[str, str], all_ents: List[Union[str, Dict]], user_ents: List[Union[str, Dict]], - uploaded: List, -): +) -> TestEntityStorage: """Mocks API endpoints used by an entity instance. Allows to define what is returned by each endpoint, and keeps track of submitted instances. @@ -28,69 +60,34 @@ def mock_comms_entity_gets( - get_user - get_instance - upload_instance - all_ids (List[Union[str, Dict]]): List of ids or curations that should be returned by the all endpoint - user_ids (List[Union[str, Dict]]): List of ids or configurations that should be returned by the user endpoint - uploaded (List): List that will be updated with uploaded instances + - edit_instance [optional] + all_ents (List[Union[str, Dict]]): List of ids or configurations to init storage. Should be returned by the + `all` endpoint. + user_ents (List[Union[str, Dict]]): List of ids or configurations that should be returned by the user endpoint. + Non-updatable. + Returns: + TestStorage: A link to the storage. 
Whenever new entity is uploaded / edited, it is updated """ get_all = comms_calls["get_all"] get_user = comms_calls["get_user"] get_instance = comms_calls["get_instance"] upload_instance = comms_calls["upload_instance"] - all_ents = [ent if isinstance(ent, dict) else {"id": ent} for ent in all_ents] - user_ents = [ent if isinstance(ent, dict) else {"id": ent} for ent in user_ents] - - instances = [generate_fn(**ent).dict() for ent in all_ents] - user_instances = [generate_fn(**ent).dict() for ent in user_ents] - mocker.patch.object(comms, get_all, return_value=instances) - mocker.patch.object(comms, get_user, return_value=user_instances) - get_behavior = get_comms_instance_behavior(generate_fn, all_ents) - mocker.patch.object( - comms, - get_instance, - side_effect=get_behavior, - ) - upload_behavior = upload_comms_instance_behavior(uploaded) - mocker.patch.object(comms, upload_instance, side_effect=upload_behavior) - - -def get_comms_instance_behavior( - generate_fn: Callable, ents: List[Union[str, Dict]] -) -> Callable: - """Function that defines a GET behavior - - Args: - generate_fn (Callable): Function to generate entity dictionaries - ents (List[Union[str, Dict]]): List of Entities configurations that are allowed to return - - Return: - function: Function that returns an entity dictionary if found, - or raises an error if not - """ - ids = [ent["id"] if isinstance(ent, dict) else ent for ent in ents] - - def get_behavior(id: int): - if id in ids: - idx = ids.index(id) - return generate_fn(**ents[idx]).dict() + def _to_dict_entity(ent: Union[str, Dict]) -> Tuple[str, Dict]: + """returns pair (id, entity-as-a-full-dict)""" + if isinstance(ent, dict): + id_, ent_params = ent["id"], ent else: - raise CommunicationRetrievalError - - return get_behavior + id_, ent_params = ent, {"id": ent} + return id_, generate_fn(**ent_params).dict() + all_ents = dict(_to_dict_entity(ent) for ent in all_ents) + user_ents = dict(_to_dict_entity(ent) for ent in user_ents) -def 
upload_comms_instance_behavior(uploaded: List) -> Callable: - """Function that defines the comms mocked behavior when uploading entities - - Args: - uploaded (List): List that will be updated with uploaded entities - - Returns: - Callable: Function containing the desired behavior - """ - - def upload_behavior(entity_dict): - uploaded.append(entity_dict) - return entity_dict + storage = TestEntityStorage(generate_fn, all_ents) + mocker.patch.object(comms, get_all, return_value=list(all_ents.values())) + mocker.patch.object(comms, get_user, return_value=list(user_ents.values())) - return upload_behavior + mocker.patch.object(comms, get_instance, side_effect=storage.get) + mocker.patch.object(comms, upload_instance, side_effect=storage.upload) + return storage diff --git a/cli/medperf/tests/mocks/cube.py b/cli/medperf/tests/mocks/cube.py index 9c1acbb8a..b8f8d828e 100644 --- a/cli/medperf/tests/mocks/cube.py +++ b/cli/medperf/tests/mocks/cube.py @@ -1,6 +1,6 @@ from typing import Optional from medperf.entities.cube import Cube - +from pydantic import Field EMPTY_FILE_HASH = "da39a3ee5e6b4b0d3255bfef95601890afd80709" @@ -14,9 +14,10 @@ class TestCube(Cube): parameters_hash: Optional[str] = EMPTY_FILE_HASH image_tarball_url: Optional[str] = "https://test.com/image.tar.gz" image_tarball_hash: Optional[str] = EMPTY_FILE_HASH - additional_files_tarball_url: Optional[str] = ( - "https://test.com/additional_files.tar.gz" + additional_files_tarball_url: Optional[str] = Field( + "https://test.com/additional_files.tar.gz", + alias="tarball_url" ) - additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH + additional_files_tarball_hash: Optional[str] = Field(EMPTY_FILE_HASH, alias="tarball_hash") state: str = "OPERATION" is_valid = True diff --git a/cli/medperf/tests/mocks/dataset.py b/cli/medperf/tests/mocks/dataset.py index b3d6d4217..faed25876 100644 --- a/cli/medperf/tests/mocks/dataset.py +++ b/cli/medperf/tests/mocks/dataset.py @@ -1,6 +1,6 @@ from typing import 
Optional, Union -from medperf.enums import Status from medperf.entities.dataset import Dataset +from pydantic import Field class TestDataset(Dataset): @@ -10,7 +10,6 @@ class TestDataset(Dataset): data_preparation_mlcube: Union[int, str] = 1 input_data_hash: str = "input_data_hash" generated_uid: str = "generated_uid" - generated_metadata: dict = {} - status: Status = Status.APPROVED.value + generated_metadata: dict = Field({}, alias="metadata") state: str = "OPERATION" submitted_as_prepared: bool = False