From a11e2643d778a7873e1460561131674ef513a72b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 28 Feb 2024 17:27:05 +0100 Subject: [PATCH 001/242] fl support code --- .gitignore | 1 + TODO | 47 ++ cli/medperf/cli.py | 4 + cli/medperf/commands/aggregator/aggregator.py | 107 +++ cli/medperf/commands/aggregator/associate.py | 38 + cli/medperf/commands/aggregator/run.py | 111 +++ cli/medperf/commands/aggregator/submit.py | 50 ++ cli/medperf/commands/execution.py | 4 +- cli/medperf/commands/training/approve.py | 38 + cli/medperf/commands/training/associate.py | 45 ++ cli/medperf/commands/training/list_assocs.py | 41 ++ cli/medperf/commands/training/lock.py | 23 + cli/medperf/commands/training/run.py | 111 +++ cli/medperf/commands/training/submit.py | 53 ++ cli/medperf/commands/training/training.py | 169 +++++ cli/medperf/comms/rest.py | 326 +++++++++ cli/medperf/config.py | 16 + cli/medperf/cryptography/__init__.py | 3 + cli/medperf/cryptography/ca.py | 150 ++++ cli/medperf/cryptography/io.py | 129 ++++ cli/medperf/cryptography/participant.py | 72 ++ cli/medperf/cryptography/utils.py | 14 + cli/medperf/entities/aggregator.py | 234 ++++++ cli/medperf/entities/cube.py | 10 +- cli/medperf/entities/training_exp.py | 316 +++++++++ cli/medperf/tests/commands/test_execution.py | 4 +- cli/medperf/utils.py | 48 ++ examples/fl/mlcube/mlcube-cpu.yaml | 40 ++ examples/fl/mlcube/mlcube-gpu.yaml | 41 ++ examples/fl/mlcube/workspace/network.yaml | 5 + .../fl/mlcube/workspace/parameters-cpu.yaml | 137 ++++ .../fl/mlcube/workspace/parameters-gpu.yaml | 137 ++++ .../mlcube/workspace/parameters-miccai.yaml | 166 +++++ examples/fl/project/Dockerfile-CPU | 31 + examples/fl/project/Dockerfile-GPU | 25 + examples/fl/project/README.md | 43 ++ examples/fl/project/aggregator.py | 50 ++ examples/fl/project/collaborator.py | 29 + examples/fl/project/fl_workspace/.workspace | 2 + .../fl/project/fl_workspace/plan/defaults | 2 + .../fl/project/fl_workspace/requirements.txt | 1 + 
examples/fl/project/hotfix.py | 667 ++++++++++++++++++ examples/fl/project/mlcube.py | 34 + examples/fl/project/utils.py | 164 +++++ mock_tokens/generate_tokens.py | 12 +- mock_tokens/tokens.json | 14 +- server/aggregator/__init__.py | 0 server/aggregator/admin.py | 3 + server/aggregator/apps.py | 6 + server/aggregator/migrations/0001_initial.py | 31 + server/aggregator/migrations/__init__.py | 0 server/aggregator/models.py | 18 + server/aggregator/serializers.py | 9 + server/aggregator/urls.py | 12 + server/aggregator/views.py | 52 ++ server/aggregator_association/__init__.py | 0 server/aggregator_association/admin.py | 3 + server/aggregator_association/apps.py | 6 + .../migrations/0001_initial.py | 36 + .../0002_experimentaggregator_training_exp.py | 22 + .../migrations/__init__.py | 0 server/aggregator_association/models.py | 29 + server/aggregator_association/permissions.py | 54 ++ server/aggregator_association/serializers.py | 156 ++++ server/aggregator_association/utils.py | 14 + server/aggregator_association/views.py | 76 ++ server/dataset/urls.py | 3 + server/key_storage/__init__.py | 0 server/key_storage/gcloud_secret_manager.py | 11 + server/key_storage/local.py | 20 + server/medperf/settings.py | 12 + server/medperf/urls.py | 2 + server/signing/__init__.py | 0 server/signing/cryptography/__init__.py | 3 + server/signing/cryptography/ca.py | 150 ++++ server/signing/cryptography/io.py | 129 ++++ server/signing/cryptography/participant.py | 72 ++ server/signing/cryptography/utils.py | 14 + server/signing/interface.py | 67 ++ server/testing_medperf.sh | 69 ++ server/testing_miccai.sh | 153 ++++ server/testing_miccai_shortcut.sh | 7 + server/traindataset_association/__init__.py | 0 server/traindataset_association/admin.py | 3 + server/traindataset_association/apps.py | 6 + .../migrations/0001_initial.py | 36 + .../0002_experimentdataset_training_exp.py | 22 + .../migrations/__init__.py | 0 server/traindataset_association/models.py | 29 + 
.../traindataset_association/permissions.py | 54 ++ .../traindataset_association/serializers.py | 136 ++++ server/traindataset_association/utils.py | 14 + server/traindataset_association/views.py | 72 ++ server/training/__init__.py | 0 server/training/admin.py | 3 + server/training/apps.py | 6 + server/training/migrations/0001_initial.py | 46 ++ server/training/migrations/__init__.py | 0 server/training/models.py | 57 ++ server/training/permissions.py | 27 + server/training/serializers.py | 97 +++ server/training/urls.py | 11 + server/training/views.py | 115 +++ server/utils/urls.py | 7 + server/utils/views.py | 92 +++ 105 files changed, 5822 insertions(+), 14 deletions(-) create mode 100644 TODO create mode 100644 cli/medperf/commands/aggregator/aggregator.py create mode 100644 cli/medperf/commands/aggregator/associate.py create mode 100644 cli/medperf/commands/aggregator/run.py create mode 100644 cli/medperf/commands/aggregator/submit.py create mode 100644 cli/medperf/commands/training/approve.py create mode 100644 cli/medperf/commands/training/associate.py create mode 100644 cli/medperf/commands/training/list_assocs.py create mode 100644 cli/medperf/commands/training/lock.py create mode 100644 cli/medperf/commands/training/run.py create mode 100644 cli/medperf/commands/training/submit.py create mode 100644 cli/medperf/commands/training/training.py create mode 100644 cli/medperf/cryptography/__init__.py create mode 100644 cli/medperf/cryptography/ca.py create mode 100644 cli/medperf/cryptography/io.py create mode 100644 cli/medperf/cryptography/participant.py create mode 100644 cli/medperf/cryptography/utils.py create mode 100644 cli/medperf/entities/aggregator.py create mode 100644 cli/medperf/entities/training_exp.py create mode 100644 examples/fl/mlcube/mlcube-cpu.yaml create mode 100644 examples/fl/mlcube/mlcube-gpu.yaml create mode 100644 examples/fl/mlcube/workspace/network.yaml create mode 100644 examples/fl/mlcube/workspace/parameters-cpu.yaml create mode 
100644 examples/fl/mlcube/workspace/parameters-gpu.yaml create mode 100644 examples/fl/mlcube/workspace/parameters-miccai.yaml create mode 100644 examples/fl/project/Dockerfile-CPU create mode 100644 examples/fl/project/Dockerfile-GPU create mode 100644 examples/fl/project/README.md create mode 100644 examples/fl/project/aggregator.py create mode 100644 examples/fl/project/collaborator.py create mode 100644 examples/fl/project/fl_workspace/.workspace create mode 100644 examples/fl/project/fl_workspace/plan/defaults create mode 100644 examples/fl/project/fl_workspace/requirements.txt create mode 100644 examples/fl/project/hotfix.py create mode 100644 examples/fl/project/mlcube.py create mode 100644 examples/fl/project/utils.py create mode 100644 server/aggregator/__init__.py create mode 100644 server/aggregator/admin.py create mode 100644 server/aggregator/apps.py create mode 100644 server/aggregator/migrations/0001_initial.py create mode 100644 server/aggregator/migrations/__init__.py create mode 100644 server/aggregator/models.py create mode 100644 server/aggregator/serializers.py create mode 100644 server/aggregator/urls.py create mode 100644 server/aggregator/views.py create mode 100644 server/aggregator_association/__init__.py create mode 100644 server/aggregator_association/admin.py create mode 100644 server/aggregator_association/apps.py create mode 100644 server/aggregator_association/migrations/0001_initial.py create mode 100644 server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py create mode 100644 server/aggregator_association/migrations/__init__.py create mode 100644 server/aggregator_association/models.py create mode 100644 server/aggregator_association/permissions.py create mode 100644 server/aggregator_association/serializers.py create mode 100644 server/aggregator_association/utils.py create mode 100644 server/aggregator_association/views.py create mode 100644 server/key_storage/__init__.py create mode 100644 
server/key_storage/gcloud_secret_manager.py create mode 100644 server/key_storage/local.py create mode 100644 server/signing/__init__.py create mode 100644 server/signing/cryptography/__init__.py create mode 100644 server/signing/cryptography/ca.py create mode 100644 server/signing/cryptography/io.py create mode 100644 server/signing/cryptography/participant.py create mode 100644 server/signing/cryptography/utils.py create mode 100644 server/signing/interface.py create mode 100644 server/testing_medperf.sh create mode 100644 server/testing_miccai.sh create mode 100644 server/testing_miccai_shortcut.sh create mode 100644 server/traindataset_association/__init__.py create mode 100644 server/traindataset_association/admin.py create mode 100644 server/traindataset_association/apps.py create mode 100644 server/traindataset_association/migrations/0001_initial.py create mode 100644 server/traindataset_association/migrations/0002_experimentdataset_training_exp.py create mode 100644 server/traindataset_association/migrations/__init__.py create mode 100644 server/traindataset_association/models.py create mode 100644 server/traindataset_association/permissions.py create mode 100644 server/traindataset_association/serializers.py create mode 100644 server/traindataset_association/utils.py create mode 100644 server/traindataset_association/views.py create mode 100644 server/training/__init__.py create mode 100644 server/training/admin.py create mode 100644 server/training/apps.py create mode 100644 server/training/migrations/0001_initial.py create mode 100644 server/training/migrations/__init__.py create mode 100644 server/training/models.py create mode 100644 server/training/permissions.py create mode 100644 server/training/serializers.py create mode 100644 server/training/urls.py create mode 100644 server/training/views.py diff --git a/.gitignore b/.gitignore index 212d8e13e..3bf5c7327 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,4 @@ cython_debug/ # Dev Environment 
Specific .vscode .venv +server/keys \ No newline at end of file diff --git a/TODO b/TODO new file mode 100644 index 000000000..3f810f427 --- /dev/null +++ b/TODO @@ -0,0 +1,47 @@ +# TODO: remove me from the repo + +FOR TUTORIAL + +- stream logs +- check benchmark execution mlcube training exp ID +- if association request failed for some reason, delete private key (or at least check if rerunning the request will simply overwrite the key) +- define output folders in medperf storage (logs for both, weights for agg) +- adding email to CN currently could be challenging. THINK + - ASSUMPTION: emails are not changed after signup + +- We now have demo data url and hash in training exp (dummy) that we don't use. + - what to say about this in miccai (I think no worries; it's hidden now) +- rethink/review about the following serializers and if necessary use atomic transactions + - association creation (dataset-training, agg-training) + - association approval (dataset-training, agg-training) + - training experiment creation (creating keypair); this could move to approval +- public/private keys uniqueness constraint while blank; check django docs on how +- fix bug about association list; /home/hasan/work/openfl_ws/medperf-private/server/utils/views.py +- pull latest medperf main +- test agg and training exp owner being same user + - basically, test the tutorial steps EXACTLY + +AFTER TUTORIAL + +- FOLLOWUP: collaborators doesn't use tensorboard logs. +- FOLLOWUP: show csr hash on approval is not necessary since now CSRs are transported securely +- test remote aggregator +- make network config better structured (URL to file? no, could be annoying.) +- move key generation after admin approval of training experiments. +- when the training experiment owner wants to "lock" the experiment + - ask for confirmation? it's an easy command and after execution there is no going back; a mess if unintended. 
+- secretstorage gcloud + +NOT SURE + +- consider if we want to enable restarts and epochs/"fresh restarts" for training exps (it's hard) +- mlcube for agg alone + +LATER / FUTURE INVESTIGATIONS + +- root key thing. +- limit network access (for now we can rely on the review of the experiment owner) +- compatibility tests +- rethink if keys are always needed (just for exps where they on't need a custom cert) +- server side verification of CSRs (check common names) + - later: the whole design might be changed diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 078d37841..3d4d1ff52 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -16,6 +16,8 @@ import medperf.commands.profile as profile import medperf.commands.association.association as association import medperf.commands.compatibility_test.compatibility_test as compatibility_test +import medperf.commands.training.training as training +import medperf.commands.aggregator.aggregator as aggregator import medperf.commands.storage as storage from medperf.utils import check_for_updates @@ -29,6 +31,8 @@ app.add_typer(compatibility_test.app, name="test", help="Manage compatibility tests") app.add_typer(auth.app, name="auth", help="Authentication") app.add_typer(storage.app, name="storage", help="Storage management") +app.add_typer(training.app, name="training", help="Training") +app.add_typer(aggregator.app, name="aggregator", help="Aggregator") @app.command("run") diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py new file mode 100644 index 000000000..4bf5925f2 --- /dev/null +++ b/cli/medperf/commands/aggregator/aggregator.py @@ -0,0 +1,107 @@ +from typing import Optional +from medperf.entities.aggregator import Aggregator +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.aggregator.submit import SubmitAggregator +from medperf.commands.aggregator.associate import AssociateAggregator 
+from medperf.commands.aggregator.run import StartAggregator + +from medperf.commands.list import EntityList +from medperf.commands.view import EntityView + +app = typer.Typer() + + +@app.command("submit") +@clean_except +def submit( + name: str = typer.Option(..., "--name", "-n", help="Name of the agg"), + address: str = typer.Option( + ..., "--address", "-a", help="UID of benchmark to associate with" + ), + port: int = typer.Option( + ..., "--port", "-p", help="UID of benchmark to associate with" + ), +): + """Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" + SubmitAggregator.run(name, address, port) + config.ui.print("✅ Done!") + + +@app.command("associate") +@clean_except +def associate( + aggregator_id: int = typer.Option( + ..., "--aggregator_id", "-a", help="UID of benchmark to associate with" + ), + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of benchmark to associate with" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" + AssociateAggregator.run(aggregator_id, training_exp_id, approved=approval) + config.ui.print("✅ Done!") + + +@app.command("start") +@clean_except +def run( + aggregator_id: int = typer.Option( + ..., "--aggregator_id", "-a", help="UID of benchmark to associate with" + ), + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of benchmark to associate with" + ), +): + """Associates a benchmark with a given mlcube or dataset. 
Only one option at a time.""" + StartAggregator.run(training_exp_id, aggregator_id) + config.ui.print("✅ Done!") + + +@app.command("ls") +@clean_except +def list( + local: bool = typer.Option(False, "--local", help="Get local aggregators"), + mine: bool = typer.Option(False, "--mine", help="Get current-user aggregators"), +): + """List aggregators stored locally and remotely from the user""" + EntityList.run( + Aggregator, + fields=["UID", "Name", "Address", "Port"], + local_only=local, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Benchmark ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. Available formats: [yaml, json]", + ), + local: bool = typer.Option( + False, + "--local", + help="Display local benchmarks if benchmark ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user benchmarks if benchmark ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. 
If not provided, the output will be displayed", + ), +): + """Displays the information of one or more aggregators""" + EntityView.run(entity_id, Aggregator, format, local, mine, output) diff --git a/cli/medperf/commands/aggregator/associate.py b/cli/medperf/commands/aggregator/associate.py new file mode 100644 index 000000000..0e21a6b66 --- /dev/null +++ b/cli/medperf/commands/aggregator/associate.py @@ -0,0 +1,38 @@ +from medperf import config +from medperf.entities.aggregator import Aggregator +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt, generate_agg_csr +from medperf.exceptions import InvalidArgumentError + + +class AssociateAggregator: + @staticmethod + def run(training_exp_id: int, agg_uid: int, approved=False): + """Associates a registered aggregator with a benchmark + + Args: + agg_uid (int): UID of the registered aggregator to associate + benchmark_uid (int): UID of the benchmark to associate with + """ + comms = config.comms + ui = config.ui + agg = Aggregator.get(agg_uid) + if agg.id is None: + msg = "The provided aggregator is not registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_id) + csr, csr_hash = generate_agg_csr(training_exp_id, agg.address, agg.id) + msg = "Please confirm that you would like to associate" + msg += f" the aggregator {agg.name} with the training exp {training_exp.name}." 
+ msg += f" The certificate signing request hash is: {csr_hash}" + msg += " [Y/n]" + + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating aggregator training association") + # TODO: delete keys if upload fails + # check if on failure, other (possible) request will overwrite key + comms.associate_aggregator(agg.id, training_exp_id, csr) + else: + ui.print("Aggregator association operation cancelled.") diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py new file mode 100644 index 000000000..d81b88ce5 --- /dev/null +++ b/cli/medperf/commands/aggregator/run.py @@ -0,0 +1,111 @@ +import os +from medperf import config +from medperf.exceptions import InvalidArgumentError +from medperf.entities.training_exp import TrainingExp +from medperf.entities.aggregator import Aggregator +from medperf.entities.cube import Cube +from medperf.utils import storage_path + + +class StartAggregator: + @classmethod + def run(cls, training_exp_id: int, agg_uid: int): + """Sets approval status for an association between a benchmark and a aggregator or mlcube + + Args: + benchmark_uid (int): Benchmark UID. + approval_status (str): Desired approval status to set for the association. + comms (Comms): Instance of Comms interface. + ui (UI): Instance of UI interface. + aggregator_uid (int, optional): Aggregator UID. Defaults to None. + mlcube_uid (int, optional): MLCube UID. Defaults to None. 
+ """ + execution = cls(training_exp_id, agg_uid) + execution.prepare() + execution.validate() + execution.prepare_agg_cert() + execution.prepare_cube() + with config.ui.interactive(): + execution.run_experiment() + + def __init__(self, training_exp_id, agg_uid) -> None: + self.training_exp_id = training_exp_id + self.agg_uid = agg_uid + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Execution: {self.training_exp.name}") + self.aggregator = Aggregator.get(self.agg_uid) + + def validate(self): + if self.aggregator.id is None: + msg = "The provided aggregator is not registered." + raise InvalidArgumentError(msg) + + training_exp_aggregator = config.comms.get_experiment_aggregator( + self.training_exp.id + ) + + if self.aggregator.id != training_exp_aggregator["id"]: + msg = "The provided aggregator is not associated." + raise InvalidArgumentError(msg) + + if self.training_exp.state != "OPERATION": + msg = "The provided training exp is not operational." 
+ raise InvalidArgumentError(msg) + + def prepare_agg_cert(self): + association = config.comms.get_aggregator_association( + self.training_exp.id, self.aggregator.id + ) + cert = association["certificate"] + cert_folder = os.path.join( + config.training_exps_storage, + str(self.training_exp.id), + config.agg_cert_folder, + str(self.aggregator.id), + ) + cert_folder = storage_path(cert_folder) + os.makedirs(cert_folder, exist_ok=True) + cert_file = os.path.join(cert_folder, "cert.crt") + with open(cert_file, "w") as f: + f.write(cert) + + self.agg_cert_path = cert_folder + + def prepare_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_mlcube, "training") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = f"Retrieving {name} cube" + cube = Cube.get(uid) + self.ui.print(f"> {name} cube download complete") + return cube + + def run_experiment(self): + task = "start_aggregator" + port = self.aggregator.port + # TODO: this overwrites existing cpu and gpu args + string_params = { + "-Pdocker.cpu_args": f"-p {port}:{port}", + "-Pdocker.gpu_args": f"-p {port}:{port}", + } + + # just for now create some output folders (TODO) + out_logs = os.path.join(self.training_exp.path, "logs") + out_weights = os.path.join(self.training_exp.path, "weights") + os.makedirs(out_logs, exist_ok=True) + os.makedirs(out_weights, exist_ok=True) + + params = { + "node_cert_folder": self.agg_cert_path, + "ca_cert_folder": self.training_exp.cert_path, + "network_config": self.aggregator.network_config_path, + "collaborators": self.training_exp.cols_path, + "output_logs": out_logs, + "output_weights": out_weights, + } + + self.ui.text = "Running Aggregator" + self.cube.run(task=task, string_params=string_params, **params) diff --git a/cli/medperf/commands/aggregator/submit.py b/cli/medperf/commands/aggregator/submit.py new file mode 100644 index 000000000..53335cd63 --- /dev/null +++ b/cli/medperf/commands/aggregator/submit.py @@ -0,0 +1,50 @@ +import 
medperf.config as config +from medperf.entities.aggregator import Aggregator +from medperf.utils import remove_path + + +class SubmitAggregator: + @classmethod + def run(cls, name, address, port): + """Submits a new cube to the medperf platform + Args: + benchmark_info (dict): benchmark information + expected keys: + name (str): benchmark name + description (str): benchmark description + docs_url (str): benchmark documentation url + demo_url (str): benchmark demo dataset url + demo_hash (str): benchmark demo dataset hash + data_preparation_mlcube (int): benchmark data preparation mlcube uid + reference_model_mlcube (int): benchmark reference model mlcube uid + evaluator_mlcube (int): benchmark data evaluator mlcube uid + """ + ui = config.ui + submission = cls(name, address, port) + + with ui.interactive(): + ui.text = "Submitting Aggregator to MedPerf" + updated_benchmark_body = submission.submit() + ui.print("Uploaded") + submission.write(updated_benchmark_body) + + def __init__(self, name, address, port): + self.ui = config.ui + # TODO: server config should be a URL... 
+ server_config = { + "address": address, + "agg_addr": address, + "port": port, + "agg_port": port, + } + self.aggregator = Aggregator(name=name, server_config=server_config) + config.tmp_paths.append(self.aggregator.path) + + def submit(self): + updated_body = self.aggregator.upload() + return updated_body + + def write(self, updated_body): + remove_path(self.aggregator.path) + aggregator = Aggregator(**updated_body) + aggregator.write() diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index d8afb2244..79d6237b9 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -80,7 +80,7 @@ def run_inference(self): try: self.model.run( task="infer", - output_logs=self.model_logs_path, + output_logs_file=self.model_logs_path, timeout=infer_timeout, data_path=data_path, output_path=preds_path, @@ -105,7 +105,7 @@ def run_evaluation(self): try: self.evaluator.run( task="evaluate", - output_logs=self.metrics_logs_path, + output_logs_file=self.metrics_logs_path, timeout=evaluate_timeout, predictions=preds_path, labels=labels_path, diff --git a/cli/medperf/commands/training/approve.py b/cli/medperf/commands/training/approve.py new file mode 100644 index 000000000..8c42fd127 --- /dev/null +++ b/cli/medperf/commands/training/approve.py @@ -0,0 +1,38 @@ +from medperf import config +from medperf.exceptions import InvalidArgumentError + + +class TrainingAssociationApproval: + @staticmethod + def run( + training_exp_id: int, + approval_status, + data_uid: int = None, + aggregator: int = None, + ): + """Sets approval status for an association between a benchmark and a dataset or mlcube + + Args: + benchmark_uid (int): Benchmark UID. + approval_status (str): Desired approval status to set for the association. + comms (Comms): Instance of Comms interface. + ui (UI): Instance of UI interface. + dataset_uid (int, optional): Dataset UID. Defaults to None. + mlcube_uid (int, optional): MLCube UID. Defaults to None. 
+ """ + comms = config.comms + too_many_resources = data_uid and aggregator + no_resource = data_uid is None and aggregator is None + if no_resource or too_many_resources: + raise InvalidArgumentError("Must provide either a dataset or aggregator") + + if data_uid: + # TODO: show CSR and ask for confirmation + comms.set_training_dataset_association_approval( + training_exp_id, data_uid, approval_status.value + ) + + if aggregator: + comms.set_aggregator_association_approval( + training_exp_id, aggregator, approval_status.value + ) diff --git a/cli/medperf/commands/training/associate.py b/cli/medperf/commands/training/associate.py new file mode 100644 index 000000000..0493d3004 --- /dev/null +++ b/cli/medperf/commands/training/associate.py @@ -0,0 +1,45 @@ +from medperf import config +from medperf.entities.dataset import Dataset +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt, generate_data_csr +from medperf.exceptions import InvalidArgumentError + + +class DatasetTrainingAssociation: + @staticmethod + def run(training_exp_id: int, data_uid: int, approved=False): + """Associates a registered dataset with a benchmark + + Args: + data_uid (int): UID of the registered dataset to associate + benchmark_uid (int): UID of the benchmark to associate with + """ + comms = config.comms + ui = config.ui + dset = Dataset.get(data_uid) + if dset.id is None: + msg = "The provided dataset is not registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_id) + + if dset.data_preparation_mlcube != training_exp.data_preparation_mlcube: + raise InvalidArgumentError( + "The specified dataset wasn't prepared for this benchmark" + ) + + email = "" # TODO + csr, csr_hash = generate_data_csr(email, data_uid, training_exp_id) + msg = "Please confirm that you would like to associate" + msg += f" the dataset {dset.name} with the training exp {training_exp.name}." 
+ msg += f" The certificate signing request hash is: {csr_hash}" + msg += " [Y/n]" + + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating dataset training association") + # TODO: delete keys if upload fails + # check if on failure, other (possible) request will overwrite key + comms.associate_training_dset(dset.id, training_exp_id, csr) + else: + ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/training/list_assocs.py b/cli/medperf/commands/training/list_assocs.py new file mode 100644 index 000000000..1b266cbf1 --- /dev/null +++ b/cli/medperf/commands/training/list_assocs.py @@ -0,0 +1,41 @@ +from tabulate import tabulate + +from medperf import config + + +class ListTrainingAssociations: + @staticmethod + def run(filter: str = None): + """Get training association requests""" + comms = config.comms + ui = config.ui + dset_assocs = comms.get_training_datasets_associations() + agg_assocs = comms.get_aggregators_associations() + + # Might be worth seeing if creating an association class that encapsulates + # most of the logic here is useful + assocs = dset_assocs + agg_assocs + if filter: + filter = filter.upper() + assocs = [assoc for assoc in assocs if assoc["approval_status"] == filter] + + assocs_info = [] + for assoc in assocs: + assoc_info = ( + assoc.get("dataset", None), + assoc.get("aggregator", None), + assoc["training_exp"], + assoc["initiated_by"], + assoc["approval_status"], + ) + assocs_info.append(assoc_info) + + headers = [ + "Dataset UID", + "Aggregator UID", + "TrainingExp UID", + "Initiated by", + "Status", + ] + tab = tabulate(assocs_info, headers=headers) + ui.print(tab) diff --git a/cli/medperf/commands/training/lock.py b/cli/medperf/commands/training/lock.py new file mode 100644 index 000000000..eca813ed6 --- /dev/null +++ b/cli/medperf/commands/training/lock.py @@ -0,0 +1,23 @@ +from medperf import config +from medperf.entities.training_exp import TrainingExp + + +class 
LockTrainingExp: + @staticmethod + def run(training_exp_id: int): + """Sets approval status for an association between a benchmark and a dataset or mlcube + + Args: + benchmark_uid (int): Benchmark UID. + approval_status (str): Desired approval status to set for the association. + comms (Comms): Instance of Comms interface. + ui (UI): Instance of UI interface. + dataset_uid (int, optional): Dataset UID. Defaults to None. + mlcube_uid (int, optional): MLCube UID. Defaults to None. + """ + # TODO: this logic will be refactored when we merge entity edit PR + comms = config.comms + comms.set_experiment_as_operational(training_exp_id) + # update training experiment + training_exp = TrainingExp.get(training_exp_id) + training_exp.write() diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py new file mode 100644 index 000000000..31c030189 --- /dev/null +++ b/cli/medperf/commands/training/run.py @@ -0,0 +1,111 @@ +import os +from medperf import config +from medperf.exceptions import InvalidArgumentError +from medperf.entities.training_exp import TrainingExp +from medperf.entities.dataset import Dataset +from medperf.entities.cube import Cube +from medperf.entities.aggregator import Aggregator +from medperf.utils import storage_path, get_dataset_common_name + + +class TrainingExecution: + @classmethod + def run(cls, training_exp_id: int, data_uid: int): + """Sets approval status for an association between a benchmark and a dataset or mlcube + + Args: + benchmark_uid (int): Benchmark UID. + approval_status (str): Desired approval status to set for the association. + comms (Comms): Instance of Comms interface. + ui (UI): Instance of UI interface. + dataset_uid (int, optional): Dataset UID. Defaults to None. + mlcube_uid (int, optional): MLCube UID. Defaults to None. 
+ """ + execution = cls(training_exp_id, data_uid) + execution.prepare() + execution.validate() + execution.prepare_data_cert() + execution.prepare_network_config() + execution.prepare_cube() + with config.ui.interactive(): + execution.run_experiment() + + def __init__(self, training_exp_id, data_uid) -> None: + self.training_exp_id = training_exp_id + self.data_uid = data_uid + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Execution: {self.training_exp.name}") + self.dataset = Dataset.get(self.data_uid) + + def validate(self): + if self.dataset.id is None: + msg = "The provided dataset is not registered." + raise InvalidArgumentError(msg) + + if self.dataset.id not in self.training_exp.datasets: + msg = "The provided dataset is not associated." + raise InvalidArgumentError(msg) + + if self.training_exp.state != "OPERATION": + msg = "The provided training exp is not operational." + raise InvalidArgumentError(msg) + + def prepare_data_cert(self): + association = config.comms.get_training_dataset_association( + self.training_exp.id, self.dataset.id + ) + cert = association["certificate"] + cert_folder = os.path.join( + config.training_exps_storage, + str(self.training_exp.id), + config.data_cert_folder, + str(self.dataset.id), + ) + cert_folder = storage_path(cert_folder) + os.makedirs(cert_folder, exist_ok=True) + cert_file = os.path.join(cert_folder, "cert.crt") + with open(cert_file, "w") as f: + f.write(cert) + + self.data_cert_path = cert_folder + + def prepare_network_config(self): + aggregator = config.comms.get_experiment_aggregator(self.training_exp.id) + aggregator = Aggregator.get(aggregator["id"]) + self.network_config_path = aggregator.network_config_path + + def prepare_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_mlcube, "training") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = f"Retrieving {name} cube" + cube = Cube.get(uid) 
+ self.ui.print(f"> {name} cube download complete") + return cube + + def run_experiment(self): + task = "train" + dataset_cn = get_dataset_common_name("", self.dataset.id, self.training_exp.id) + # TODO: this overwrites existing env args + # TODO: CUDA_VISIBLE_DEVICES should be in dockerfile maybe + string_params = { + "-Pdocker.env_args": f'-e COLLABORATOR_CN={dataset_cn} -e CUDA_VISIBLE_DEVICES="0"', + } + + # just for now create some output folders (TODO) + out_logs = os.path.join(self.training_exp.path, "data_logs") + os.makedirs(out_logs, exist_ok=True) + + params = { + "data_path": self.dataset.data_path, + "labels_path": self.dataset.labels_path, + "node_cert_folder": self.data_cert_path, + "ca_cert_folder": self.training_exp.cert_path, + "network_config": self.network_config_path, + "output_logs": out_logs, + } + self.ui.text = "Training" + self.cube.run(task=task, string_params=string_params, **params) diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py new file mode 100644 index 000000000..89fc82043 --- /dev/null +++ b/cli/medperf/commands/training/submit.py @@ -0,0 +1,53 @@ +import os + +import medperf.config as config +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import remove_path + + +class SubmitTrainingExp: + @classmethod + def run(cls, training_exp_info: dict): + """Submits a new cube to the medperf platform + Args: + benchmark_info (dict): benchmark information + expected keys: + name (str): benchmark name + description (str): benchmark description + docs_url (str): benchmark documentation url + demo_url (str): benchmark demo dataset url + demo_hash (str): benchmark demo dataset hash + data_preparation_mlcube (int): benchmark data preparation mlcube uid + reference_model_mlcube (int): benchmark reference model mlcube uid + evaluator_mlcube (int): benchmark data evaluator mlcube uid + """ + ui = config.ui + submission = 
class SubmitTrainingExp:
    @classmethod
    def run(cls, training_exp_info: dict):
        """Submits a new training experiment to the medperf platform

        Args:
            training_exp_info (dict): training experiment information
                expected keys:
                    name (str): training experiment name
                    description (str): training experiment description
                    docs_url (str): documentation url
                    fl_mlcube (int): FL MLCube UID
                    data_preparation_mlcube (int): data preparation MLCube UID
        """
        ui = config.ui
        submission = cls(training_exp_info)

        with ui.interactive():
            ui.text = "Getting FL MLCube"
            submission.get_mlcube()
            ui.print("> Completed retrieving FL MLCube")
            ui.text = "Submitting TrainingExp to MedPerf"
            updated_benchmark_body = submission.submit()
            ui.print("Uploaded")
            submission.write(updated_benchmark_body)

    def __init__(self, training_exp_info: dict):
        self.ui = config.ui
        self.training_exp = TrainingExp(**training_exp_info)
        # Stage the experiment locally; flag the folder for cleanup on failure.
        config.tmp_paths.append(self.training_exp.path)

    def get_mlcube(self):
        # Download the FL MLCube up front so submission fails early if it is invalid.
        mlcube_id = self.training_exp.fl_mlcube
        Cube.get(mlcube_id)

    def submit(self):
        """Upload the experiment; returns the server-updated body (incl. UID)."""
        updated_body = self.training_exp.upload()
        return updated_body

    def write(self, updated_body):
        # Replace the temporary local copy with one keyed by the server-assigned UID.
        remove_path(self.training_exp.path)
        training_exp = TrainingExp(**updated_body)
        training_exp.write()
the benchmark" + ), + docs_url: str = typer.Option("", "--docs-url", "-u", help="URL to documentation"), + prep_mlcube: int = typer.Option( + ..., "--prep-mlcube", "-p", help="prep MLCube UID" + ), + fl_mlcube: int = typer.Option( + ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" + ), +): + """Submits a new benchmark to the platform""" + training_exp_info = { + "name": name, + "description": description, + "docs_url": docs_url, + "fl_mlcube": fl_mlcube, + "demo_dataset_tarball_url": "link", # TODO later + "demo_dataset_tarball_hash": "hash", + "demo_dataset_generated_uid": "uid", + "data_preparation_mlcube": prep_mlcube, + } + SubmitTrainingExp.run(training_exp_info) + config.ui.print("✅ Done!") + + +@app.command("lock") +@clean_except +def lock( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + LockTrainingExp.run(training_exp_id) + config.ui.print("✅ Done!") + + +@app.command("run") +@clean_except +def run( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + ..., "--data_uid", "-d", help="Registered Dataset UID" + ), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + TrainingExecution.run(training_exp_id, data_uid) + config.ui.print("✅ Done!") + + +@app.command("associate_dataset") +@clean_except +def associate( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + ..., "--data_uid", "-d", help="Registered Dataset UID" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + DatasetTrainingAssociation.run(training_exp_id, data_uid, 
approved=approval) + config.ui.print("✅ Done!") + + +@app.command("approve_association") +@clean_except +def approve( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + None, "--data_uid", "-d", help="Registered Dataset UID" + ), + aggregator: int = typer.Option( + None, "--aggregator", "-a", help="Registered Dataset UID" + ), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + TrainingAssociationApproval.run( + training_exp_id, Status.APPROVED, data_uid, aggregator + ) + config.ui.print("✅ Done!") + + +@app.command("list_associations") +@clean_except +def list(filter: Optional[str] = typer.Argument(None)): + """Display all training associations related to the current user. + + Args: + filter (str, optional): Filter training associations by approval status. + Defaults to displaying all user training associations. + """ + ListTrainingAssociations.run(filter) + + +@app.command("ls") +@clean_except +def list( + local: bool = typer.Option(False, "--local", help="Get local exps"), + mine: bool = typer.Option(False, "--mine", help="Get current-user exps"), +): + """List experiments stored locally and remotely from the user""" + EntityList.run( + TrainingExp, + fields=["UID", "Name", "State", "Approval Status", "Registered"], + local_only=local, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Benchmark ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. 
Available formats: [yaml, json]", + ), + local: bool = typer.Option( + False, + "--local", + help="Display local benchmarks if benchmark ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user benchmarks if benchmark ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. If not provided, the output will be displayed", + ), +): + """Displays the information of one or more benchmarks""" + EntityView.run(entity_id, TrainingExp, format, local, mine, output) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 972c0e1c9..9b15dea6d 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -147,6 +147,20 @@ def __set_approval_status(self, url: str, status: str) -> requests.Response: res = self.__auth_put(url, json=data) return res + def __set_state(self, url: str, state: str) -> requests.Response: + """Sets the approval status of a resource + + Args: + url (str): URL to the resource to update + status (str): approval status to set + + Returns: + requests.Response: Response object returned by the update + """ + data = {"state": state} + res = self.__auth_put(url, json=data) + return res + def get_current_user(self): """Retrieve the currently-authenticated user information""" res = self.__auth_get(f"{self.server_url}/me/") @@ -509,3 +523,315 @@ def set_mlcube_association_priority( raise CommunicationRequestError( f"Could not set the priority of mlcube {mlcube_uid} within the benchmark {benchmark_uid}: {details}" ) + + def upload_training_exp(self, training_exp_dict: dict) -> int: + """Uploads a new training_exp to the server. 
+ + Args: + benchmark_dict (dict): benchmark_data to be uploaded + + Returns: + int: UID of newly created benchmark + """ + res = self.__auth_post(f"{self.server_url}/training/", json=training_exp_dict) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"Could not upload training exp: {details}" + ) + return res.json() + + def get_training_exp(self, training_exp_id: int) -> dict: + """Retrieves the training_exp specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark + + Returns: + dict: benchmark specification + """ + res = self.__auth_get(f"{self.server_url}/training/{training_exp_id}") + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"the specified training_exp doesn't exist: {details}" + ) + return res.json() + + def get_experiment_datasets(self, training_exp_id: int) -> dict: + """Retrieves all approved datasets for a given training_exp + + Args: + benchmark_id (int): benchmark ID to retrieve results from + + Returns: + dict: dictionary with the contents of each result in the specified benchmark + """ + results = self.__get_list( + f"{self.server_url}/training/{training_exp_id}/datasets" + ) + results = [dataset["id"] for dataset in results] + + return results + + def get_experiment_aggregator(self, training_exp_id: int) -> dict: + """Retrieves the experiment aggregator + + Args: + benchmark_id (int): benchmark ID to retrieve results from + + Returns: + dict: dictionary with the contents of each result in the specified benchmark + """ + + res = self.__auth_get( + f"{self.server_url}/training/{training_exp_id}/aggregator" + ) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"There was a problem when retrieving the aggregator: {details}" + ) + return res.json() 
+ + def set_experiment_as_operational(self, training_exp_id: int) -> dict: + """lock experiment (set as operational) + + Args: + benchmark_id (int): benchmark ID to retrieve results from + + Returns: + dict: dictionary with the contents of each result in the specified benchmark + """ + + url = f"{self.server_url}/training/{training_exp_id}/" + res = self.__set_state(url, "OPERATION") + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError( + f"Could not set operational state for experiment {training_exp_id}: {details}" + ) + + def upload_aggregator(self, aggregator_dict: dict) -> int: + """Uploads a new aggregator to the server. + + Args: + benchmark_dict (dict): benchmark_data to be uploaded + + Returns: + int: UID of newly created benchmark + """ + res = self.__auth_post(f"{self.server_url}/aggregators/", json=aggregator_dict) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError(f"Could not upload aggregator: {details}") + return res.json() + + def get_aggregator(self, aggregator_id: int) -> dict: + """Retrieves the aggregator specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark + + Returns: + dict: benchmark specification + """ + res = self.__auth_get(f"{self.server_url}/aggregators/{aggregator_id}") + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"the specified aggregator doesn't exist: {details}" + ) + return res.json() + + def associate_aggregator(self, aggregator_id: int, training_exp_id: int, csr: str): + """Create a aggregator experiment association + + Args: + data_uid (int): Registered dataset UID + benchmark_uid (int): Benchmark UID + metadata (dict, optional): Additional metadata. Defaults to {}. 
+ """ + data = { + "aggregator": aggregator_id, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + "signing_request": csr, + } + res = self.__auth_post( + f"{self.server_url}/aggregators/training_experiments/", json=data + ) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError( + f"Could not associate aggregator to training_exp: {details}" + ) + + def set_aggregator_association_approval( + self, training_exp_id: int, aggregator_id: int, status: str + ): + """Approves a aggregator association + + Args: + dataset_uid (int): Dataset UID + benchmark_uid (int): Benchmark UID + status (str): Approval status to set for the association + """ + url = f"{self.server_url}/aggregators/{aggregator_id}/training_experiments/{training_exp_id}/" + res = self.__set_approval_status(url, status) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError( + "Could not approve association between aggregator" + f"{aggregator_id} and training_exp {training_exp_id}: {details}" + ) + + def get_aggregator_association( + self, training_exp_id: int, aggregator_id: int + ) -> dict: + """Retrieves the aggregator association specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark + + Returns: + dict: benchmark specification + """ + url = f"{self.server_url}/aggregators/{aggregator_id}/training_experiments/{training_exp_id}/" + res = self.__auth_get(url) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"There was a problem when retrieving the association: {details}" + ) + return res.json() + + def associate_training_dset(self, data_uid: int, training_exp_id: int, csr: str): + """Create a Dataset experiment association + + Args: + data_uid (int): Registered dataset UID + benchmark_uid (int): 
Benchmark UID + metadata (dict, optional): Additional metadata. Defaults to {}. + """ + data = { + "dataset": data_uid, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + "signing_request": csr, + } + res = self.__auth_post( + f"{self.server_url}/datasets/training_experiments/", json=data + ) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError( + f"Could not associate dataset to training_exp: {details}" + ) + + def set_training_dataset_association_approval( + self, training_exp_id: int, dataset_uid: int, status: str + ): + """Approves a trainining dataset association + + Args: + dataset_uid (int): Dataset UID + benchmark_uid (int): Benchmark UID + status (str): Approval status to set for the association + """ + url = f"{self.server_url}/datasets/{dataset_uid}/training_experiments/{training_exp_id}/" + res = self.__set_approval_status(url, status) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError( + "Could not approve association between dataset" + f"{dataset_uid} and training_exp {training_exp_id}: {details}" + ) + + def get_training_dataset_association( + self, training_exp_id: int, dataset_uid: int + ) -> dict: + """Retrieves the training dataset association specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark + + Returns: + dict: benchmark specification + """ + url = f"{self.server_url}/datasets/{dataset_uid}/training_experiments/{training_exp_id}/" + res = self.__auth_get(url) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError( + f"There was a problem when retrieving the association: {details}" + ) + return res.json() + + def get_aggregators(self) -> List[dict]: + """Retrieves all aggregators + + Returns: + List[dict]: List of aggregators + """ + 
aggregators = self.__get_list(f"{self.server_url}/aggregators") + return aggregators + + def get_user_aggregators(self) -> dict: + """Retrieves all aggregators registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + aggregators = self.__get_list(f"{self.server_url}/me/aggregators/") + return aggregators + + def get_training_exps(self) -> List[dict]: + """Retrieves all training_exps + + Returns: + List[dict]: List of training_exps + """ + training_exps = self.__get_list(f"{self.server_url}/training") + return training_exps + + def get_user_training_exps(self) -> dict: + """Retrieves all training_exps registered by the user + + Returns: + dict: dictionary with the contents of each result registration query + """ + training_exps = self.__get_list(f"{self.server_url}/me/training/") + return training_exps + + def get_training_datasets_associations(self) -> List[dict]: + """Get all training dataset associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + assocs = self.__get_list( + f"{self.server_url}/me/datasets/training_associations/" + ) + return assocs + + def get_aggregators_associations(self) -> List[dict]: + """Get all aggregator associations related to the current user + + Returns: + List[dict]: List containing all associations information + """ + assocs = self.__get_list(f"{self.server_url}/me/aggregators/associations/") + return assocs diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 0391e045e..af2b81df9 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -60,6 +60,8 @@ results_folder = "results" predictions_folder = "predictions" tests_folder = "tests" +training_folder = "training" +aggregators_folder = "aggregators" default_base_storage = str(Path.home().resolve() / ".medperf") @@ -112,6 +114,14 @@ "base": default_base_storage, "name": tests_folder, }, + "training_folder": { + "base": default_base_storage, + 
"name": training_folder, + }, + "aggregators_folder": { + "base": default_base_storage, + "name": aggregators_folder + }, } root_folders = [ @@ -140,6 +150,12 @@ log_file = "medperf.log" tarball_filename = "tmp.tar.gz" demo_dset_paths_file = "paths.yaml" +training_exps_filename = "training-info.yaml" +training_exp_cols_filename = "cols.yaml" +agg_cert_folder = "agg_cert" +data_cert_folder = "data_cert" +ca_cert_folder = "ca_cert" +network_config_filename = "network.yaml" # MLCube assets conventions cube_filename = "mlcube.yaml" diff --git a/cli/medperf/cryptography/__init__.py b/cli/medperf/cryptography/__init__.py new file mode 100644 index 000000000..b3f394d12 --- /dev/null +++ b/cli/medperf/cryptography/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.cryptography package.""" diff --git a/cli/medperf/cryptography/ca.py b/cli/medperf/cryptography/ca.py new file mode 100644 index 000000000..d651919a4 --- /dev/null +++ b/cli/medperf/cryptography/ca.py @@ -0,0 +1,150 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Cryptography CA utilities.""" + +import datetime +import uuid +from typing import Tuple + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509.base import Certificate +from cryptography.x509.base import CertificateSigningRequest +from cryptography.x509.extensions import ExtensionNotFound +from cryptography.x509.name import Name +from cryptography.x509.oid import ExtensionOID +from cryptography.x509.oid import NameOID + + +def generate_root_cert( + common_name: str = "Simple Root CA", days_to_expiration: int = 365 +) -> Tuple[RSAPrivateKey, Certificate]: + """Generate_root_certificate.""" + now = 
datetime.datetime.utcnow() + expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) + + # Generate private key + root_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=3072, backend=default_backend() + ) + + # Generate public key + root_public_key = root_private_key.public_key() + builder = x509.CertificateBuilder() + subject = x509.Name( + [ + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Root CA"), + ] + ) + issuer = subject + builder = builder.subject_name(subject) + builder = builder.issuer_name(issuer) + + builder = builder.not_valid_before(now) + builder = builder.not_valid_after(now + expiration_delta) + builder = builder.serial_number(int(uuid.uuid4())) + builder = builder.public_key(root_public_key) + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + + # Sign the CSR + certificate = builder.sign( + private_key=root_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + + return root_private_key, certificate + + +def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: + """Generate signing CSR.""" + # Generate private key + signing_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=3072, backend=default_backend() + ) + + builder = x509.CertificateSigningRequestBuilder() + subject = x509.Name( + [ + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), + x509.NameAttribute(NameOID.COMMON_NAME, "Simple Signing CA"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Signing CA"), + ] + ) + builder = 
builder.subject_name(subject) + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + + # Sign the CSR + csr = builder.sign( + private_key=signing_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + + return signing_private_key, csr + + +def sign_certificate( + csr: CertificateSigningRequest, + issuer_private_key: RSAPrivateKey, + issuer_name: Name, + days_to_expiration: int = 365, + ca: bool = False, +) -> Certificate: + """ + Sign the incoming CSR request. + + Args: + csr : Certificate Signing Request object + issuer_private_key : Root CA private key if the request is for the signing + CA; Signing CA private key otherwise + issuer_name : x509 Name + days_to_expiration : int (365 days by default) + ca : Is this a certificate authority + """ + now = datetime.datetime.utcnow() + expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) + + builder = x509.CertificateBuilder() + builder = builder.subject_name(csr.subject) + builder = builder.issuer_name(issuer_name) + builder = builder.not_valid_before(now) + builder = builder.not_valid_after(now + expiration_delta) + builder = builder.serial_number(int(uuid.uuid4())) + builder = builder.public_key(csr.public_key()) + builder = builder.add_extension( + x509.BasicConstraints(ca=ca, path_length=None), + critical=True, + ) + try: + builder = builder.add_extension( + csr.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ).value, + critical=False, + ) + except ExtensionNotFound: + pass # Might not have alternative name + + signed_cert = builder.sign( + private_key=issuer_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + return signed_cert diff --git a/cli/medperf/cryptography/io.py b/cli/medperf/cryptography/io.py new file mode 100644 index 000000000..52bfc5e95 --- /dev/null +++ b/cli/medperf/cryptography/io.py @@ -0,0 +1,129 @@ +# Copyright (C) 2020-2023 Intel Corporation +# 
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Cryptography IO utilities."""

import os
from hashlib import sha384
from pathlib import Path
from typing import Tuple

from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.base import Certificate
from cryptography.x509.base import CertificateSigningRequest


def read_key(path: Path) -> RSAPrivateKey:
    """
    Read a PEM-encoded RSA private key.

    Args:
        path : Path (pathlib)

    Returns:
        private_key

    Raises:
        TypeError: if the file does not contain an RSA private key
    """
    with open(path, 'rb') as f:
        pem_data = f.read()

    signing_key = load_pem_private_key(pem_data, password=None)
    # Explicit exception instead of `assert`: asserts are stripped under -O.
    if not isinstance(signing_key, rsa.RSAPrivateKey):
        raise TypeError(f"{path} does not contain an RSA private key")
    return signing_key


def write_key(key: RSAPrivateKey, path: Path) -> None:
    """
    Write a private key as PEM with owner-only (0600) permissions.

    Args:
        key : RSA private key object
        path : Path (pathlib)
    """
    def key_opener(path, flags):
        # Restrict permissions at creation time, before any bytes are written.
        return os.open(path, flags, mode=0o600)

    with open(path, 'wb', opener=key_opener) as f:
        f.write(key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption()
        ))


def read_crt(path: Path) -> Certificate:
    """
    Read a signed TLS certificate.

    Args:
        path : Path (pathlib)

    Returns:
        Cryptography TLS Certificate object

    Raises:
        TypeError: if the file does not contain an x509 certificate
    """
    with open(path, 'rb') as f:
        pem_data = f.read()

    certificate = x509.load_pem_x509_certificate(pem_data)
    if not isinstance(certificate, x509.Certificate):
        raise TypeError(f"{path} does not contain an x509 certificate")
    return certificate


def write_crt(certificate: Certificate, path: Path) -> None:
    """
    Write a cryptography certificate / csr as PEM.

    Args:
        certificate : cryptography csr / certificate object
        path : Path (pathlib)
    """
    with open(path, 'wb') as f:
        f.write(certificate.public_bytes(
            encoding=serialization.Encoding.PEM,
        ))


def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]:
    """
    Read a certificate signing request.

    Args:
        path : Path (pathlib)

    Returns:
        Tuple of (Cryptography CSR object, hex digest of its PEM bytes)

    Raises:
        TypeError: if the file does not contain a certificate signing request
    """
    with open(path, 'rb') as f:
        pem_data = f.read()

    csr = x509.load_pem_x509_csr(pem_data)
    if not isinstance(csr, x509.CertificateSigningRequest):
        raise TypeError(f"{path} does not contain a certificate signing request")
    return csr, get_csr_hash(csr)


def get_csr_hash(certificate: CertificateSigningRequest) -> str:
    """
    Get the SHA-384 hash of a certificate / csr's PEM encoding.

    Args:
        certificate : Cryptography CSR object

    Returns:
        Hex digest of the PEM bytes
    """
    hasher = sha384()
    encoded_bytes = certificate.public_bytes(
        encoding=serialization.Encoding.PEM,
    )
    hasher.update(encoded_bytes)
    return hasher.hexdigest()
and client.""" + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=3072, + backend=default_backend() + ) + + builder = x509.CertificateSigningRequestBuilder() + subject = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ]) + builder = builder.subject_name(subject) + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True, + ) + if server: + builder = builder.add_extension( + x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.SERVER_AUTH]), + critical=True + ) + + else: + builder = builder.add_extension( + x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=True + ) + + builder = builder.add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + content_commitment=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False + ), + critical=True + ) + + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(common_name)]), + critical=False + ) + + # Sign the CSR + csr = builder.sign( + private_key=private_key, algorithm=hashes.SHA384(), + backend=default_backend() + ) + + return private_key, csr diff --git a/cli/medperf/cryptography/utils.py b/cli/medperf/cryptography/utils.py new file mode 100644 index 000000000..03f9eb940 --- /dev/null +++ b/cli/medperf/cryptography/utils.py @@ -0,0 +1,14 @@ +from cryptography.hazmat.primitives import serialization +from cryptography import x509 + + +def cert_to_str(cert): + return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") + + +def str_to_cert(cert_str): + return x509.load_pem_x509_certificate(cert_str.encode("utf-8")) + + +def str_to_csr(csr_str): + return x509.load_pem_x509_csr(csr_str.encode("utf-8")) diff --git a/cli/medperf/entities/aggregator.py b/cli/medperf/entities/aggregator.py new file mode 100644 index 000000000..0127cb410 --- 
import os
import yaml
import logging
import hashlib
from typing import List, Optional, Union

from medperf.utils import storage_path
from medperf.entities.interface import Entity, Uploadable
from medperf.entities.schemas import MedperfSchema
from medperf.exceptions import (
    InvalidArgumentError,
    MedperfException,
    CommunicationRetrievalError,
)
import medperf.config as config
from medperf.account_management import get_medperf_user_data


class Aggregator(Entity, MedperfSchema, Uploadable):
    """
    Class representing a federated-learning aggregator server.

    An aggregator is defined by its server configuration (address and port).
    Registered aggregators are stored locally under their server-assigned UID;
    unregistered ones under a hash of their server configuration. The local
    folder also holds the aggregator's network configuration file.
    """

    server_config: Optional[dict]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.address = self.server_config["address"]
        self.port = self.server_config["port"]
        self.generated_uid = self.__generate_uid()

        # Registered aggregators live under their server UID; local-only ones
        # under the config-derived hash.
        path = storage_path(config.aggregator_storage)
        if self.id:
            path = os.path.join(path, str(self.id))
        else:
            path = os.path.join(path, self.generated_uid)

        self.path = path
        self.network_config_path = os.path.join(path, config.network_config_filename)

    def __generate_uid(self):
        """A helper that generates a unique hash for a server config."""
        # sha1 is used as an identifier here, not as a security measure.
        params = str(self.server_config)
        return hashlib.sha1(params.encode()).hexdigest()

    def todict(self):
        return self.extended_dict()

    @classmethod
    def all(cls, local_only: bool = False, filters: dict = {}) -> List["Aggregator"]:
        """Gets and creates instances of all aggregators, remote and local

        Args:
            local_only (bool, optional): Wether to retrieve only local entities.
                Defaults to False.
            filters (dict, optional): key-value pairs specifying filters to
                apply to the list of entities.

        Returns:
            List[Aggregator]: a list of Aggregator instances.
        """
        logging.info("Retrieving all aggregators")
        aggs = []
        if not local_only:
            aggs = cls.__remote_all(filters=filters)

        remote_uids = set([agg.id for agg in aggs])

        # Append local-only aggregators not already known to the server.
        local_aggs = cls.__local_all()
        aggs += [agg for agg in local_aggs if agg.id not in remote_uids]

        return aggs

    @classmethod
    def __remote_all(cls, filters: dict) -> List["Aggregator"]:
        """Best-effort retrieval of server-side aggregators; empty on comms failure."""
        aggs = []
        try:
            comms_fn = cls.__remote_prefilter(filters)
            aggs_meta = comms_fn()
            aggs = [cls(**meta) for meta in aggs_meta]
        except CommunicationRetrievalError:
            msg = "Couldn't retrieve all aggregators from the server"
            logging.warning(msg)

        return aggs

    @classmethod
    def __remote_prefilter(cls, filters: dict) -> callable:
        """Applies filtering logic that must be done before retrieving remote entities

        Args:
            filters (dict): filters to apply

        Returns:
            callable: A function for retrieving remote entities with the applied prefilters
        """
        comms_fn = config.comms.get_aggregators
        if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]:
            comms_fn = config.comms.get_user_aggregators
        return comms_fn

    @classmethod
    def __local_all(cls) -> List["Aggregator"]:
        """Builds Aggregator instances from every folder in local aggregator storage.

        Raises:
            MedperfException: if the storage directory cannot be iterated.
        """
        aggs = []
        aggregator_storage = storage_path(config.aggregator_storage)
        try:
            uids = next(os.walk(aggregator_storage))[1]
        except StopIteration:
            msg = "Couldn't iterate over the aggregator directory"
            logging.warning(msg)
            raise MedperfException(msg)

        for uid in uids:
            local_meta = cls.__get_local_dict(uid)
            agg = cls(**local_meta)
            aggs.append(agg)

        return aggs
+ If the aggregator is present in the user's machine then it retrieves it from there. + + Args: + agg_uid (str): server UID of the aggregator + + Returns: + Aggregator: Specified Aggregator Instance + """ + if not str(agg_uid).isdigit() or local_only: + return cls.__local_get(agg_uid) + + try: + return cls.__remote_get(agg_uid) + except CommunicationRetrievalError: + logging.warning(f"Getting Aggregator {agg_uid} from comms failed") + logging.info(f"Looking for aggregator {agg_uid} locally") + return cls.__local_get(agg_uid) + + @classmethod + def __remote_get(cls, agg_uid: int) -> "Aggregator": + """Retrieves and creates a Aggregator instance from the comms instance. + If the aggregator is present in the user's machine then it retrieves it from there. + + Args: + agg_uid (str): server UID of the aggregator + + Returns: + Aggregator: Specified Aggregator Instance + """ + logging.debug(f"Retrieving aggregator {agg_uid} remotely") + meta = config.comms.get_aggregator(agg_uid) + aggregator = cls(**meta) + aggregator.write() + return aggregator + + @classmethod + def __local_get(cls, agg_uid: Union[str, int]) -> "Aggregator": + """Retrieves and creates a Aggregator instance from the comms instance. + If the aggregator is present in the user's machine then it retrieves it from there. 
+ + Args: + agg_uid (str): server UID of the aggregator + + Returns: + Aggregator: Specified Aggregator Instance + """ + logging.debug(f"Retrieving aggregator {agg_uid} locally") + local_meta = cls.__get_local_dict(agg_uid) + aggregator = cls(**local_meta) + return aggregator + + def write(self): + logging.info(f"Updating registration information for aggregator: {self.id}") + logging.debug(f"registration information: {self.todict()}") + regfile = os.path.join(self.path, config.reg_file) + os.makedirs(self.path, exist_ok=True) + with open(regfile, "w") as f: + yaml.dump(self.todict(), f) + + # write network config + with open(self.network_config_path, "w") as f: + yaml.dump(self.server_config, f) + + return regfile + + def upload(self): + """Uploads the registration information to the comms. + + Args: + comms (Comms): Instance of the comms interface. + """ + if self.for_test: + raise InvalidArgumentError("Cannot upload test aggregators.") + aggregator_dict = self.todict() + updated_aggregator_dict = config.comms.upload_aggregator(aggregator_dict) + return updated_aggregator_dict + + @classmethod + def __get_local_dict(cls, aggregator_uid): + aggregator_path = os.path.join( + storage_path(config.aggregator_storage), str(aggregator_uid) + ) + regfile = os.path.join(aggregator_path, config.reg_file) + if not os.path.exists(regfile): + raise InvalidArgumentError( + "The requested aggregator information could not be found locally" + ) + with open(regfile, "r") as f: + reg = yaml.safe_load(f) + return reg + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Generated Hash": self.generated_uid, + "Address": self.server_config["address"], + "Port": self.server_config["port"], + "Created At": self.created_at, + "Registered": self.is_registered, + } diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 6e6f65677..b748e17f4 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -298,7 +298,7 
@@ def download_run_files(self): def run( self, task: str, - output_logs: str = None, + output_logs_file: str = None, string_params: Dict[str, str] = {}, timeout: int = None, read_protected_input: bool = True, @@ -316,7 +316,9 @@ def run( """ kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f" --mlcube={self.cube_path} --task={task} --platform={config.platform} --network=none" + cmd += f" --mlcube={self.cube_path} --task={task} --platform={config.platform}" + if task not in ["train", "start_aggregator"]: + cmd += " --network=none" if config.gpus is not None: cmd += f" --gpus={config.gpus}" if read_protected_input: @@ -367,8 +369,8 @@ def run( proc = proc_wrapper.proc proc_out = combine_proc_sp_text(proc) - if output_logs is not None: - with open(output_logs, "w") as f: + if output_logs_file is not None: + with open(output_logs_file, "w") as f: f.write(proc_out) if proc.exitstatus != 0: raise ExecutionError("There was an error while executing the cube") diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py new file mode 100644 index 000000000..46764845b --- /dev/null +++ b/cli/medperf/entities/training_exp.py @@ -0,0 +1,316 @@ +import os +from medperf.exceptions import MedperfException +import yaml +import logging +from typing import List, Optional, Union +from pydantic import HttpUrl, Field, validator + +import medperf.config as config +from medperf.entities.interface import Entity, Uploadable +from medperf.utils import get_dataset_common_name, storage_path +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema +from medperf.account_management import get_medperf_user_data, read_user_account + + +class TrainingExp( + Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema +): + """ + Class representing a TrainingExp + + a training_exp is a bundle of assets that 
enables quantitative + measurement of the performance of AI models for a specific + clinical problem. A TrainingExp instance contains information + regarding how to prepare datasets for execution, as well as + what models to run and how to evaluate them. + """ + + description: Optional[str] = Field(None, max_length=20) + docs_url: Optional[HttpUrl] + demo_dataset_tarball_url: Optional[str] + demo_dataset_tarball_hash: Optional[str] + demo_dataset_generated_uid: Optional[str] + data_preparation_mlcube: int + fl_mlcube: int + public_key: Optional[str] + datasets: List[int] = None + metadata: dict = {} + user_metadata: dict = {} + state: str = "DEVELOPMENT" + + @validator("datasets", pre=True, always=True) + def set_default_datasets_value(cls, value, values, **kwargs): + if not value: + # Empty or None value assigned + return [] + return value + + def __init__(self, *args, **kwargs): + """Creates a new training_exp instance + + Args: + training_exp_desc (Union[dict, TrainingExpModel]): TrainingExp instance description + """ + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + path = storage_path(config.training_exps_storage) + if self.id: + path = os.path.join(path, str(self.id)) + else: + path = os.path.join(path, self.generated_uid) + self.path = path + self.cert_path = os.path.join(path, config.ca_cert_folder) + self.cols_path = os.path.join(path, config.training_exp_cols_filename) + + @classmethod + def all(cls, local_only: bool = False, filters: dict = {}) -> List["TrainingExp"]: + """Gets and creates instances of all retrievable training_exps + + Args: + local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + + Returns: + List[TrainingExp]: a list of TrainingExp instances. 
+ """ + logging.info("Retrieving all training_exps") + training_exps = [] + + if not local_only: + training_exps = cls.__remote_all(filters=filters) + + remote_uids = set([training_exp.id for training_exp in training_exps]) + + local_training_exps = cls.__local_all() + + training_exps += [ + training_exp + for training_exp in local_training_exps + if training_exp.id not in remote_uids + ] + + return training_exps + + @classmethod + def __remote_all(cls, filters: dict) -> List["TrainingExp"]: + training_exps = [] + try: + comms_fn = cls.__remote_prefilter(filters) + training_exps_meta = comms_fn() + for training_exp_meta in training_exps_meta: + # Loading all related models for all training_exps could be expensive. + # Most probably not necessary when getting all training_exps. + # If associated models for a training_exp are needed then use TrainingExp.get() + training_exp_meta["datasets"] = [] + training_exps = [cls(**meta) for meta in training_exps_meta] + except CommunicationRetrievalError: + msg = "Couldn't retrieve all training_exps from the server" + logging.warning(msg) + + return training_exps + + @classmethod + def __remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_training_exps + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_training_exps + return comms_fn + + @classmethod + def __local_all(cls) -> List["TrainingExp"]: + training_exps = [] + training_exps_storage = storage_path(config.training_exps_storage) + try: + uids = next(os.walk(training_exps_storage))[1] + except StopIteration: + msg = "Couldn't iterate over training_exps directory" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + meta = 
cls.__get_local_dict(uid) + training_exp = cls(**meta) + training_exps.append(training_exp) + + return training_exps + + @classmethod + def get( + cls, training_exp_uid: Union[str, int], local_only: bool = False + ) -> "TrainingExp": + """Retrieves and creates a TrainingExp instance from the server. + If training_exp already exists in the platform then retrieve that + version. + + Args: + training_exp_uid (str): UID of the training_exp. + comms (Comms): Instance of a communication interface. + + Returns: + TrainingExp: a TrainingExp instance with the retrieved data. + """ + + if not str(training_exp_uid).isdigit() or local_only: + return cls.__local_get(training_exp_uid) + + try: + return cls.__remote_get(training_exp_uid) + except CommunicationRetrievalError: + logging.warning(f"Getting TrainingExp {training_exp_uid} from comms failed") + logging.info(f"Looking for training_exp {training_exp_uid} locally") + return cls.__local_get(training_exp_uid) + + @classmethod + def __remote_get(cls, training_exp_uid: int) -> "TrainingExp": + """Retrieves and creates a Dataset instance from the comms instance. + If the dataset is present in the user's machine then it retrieves it from there. + + Args: + dset_uid (str): server UID of the dataset + + Returns: + Dataset: Specified Dataset Instance + """ + logging.debug(f"Retrieving training_exp {training_exp_uid} remotely") + training_exp_dict = config.comms.get_training_exp(training_exp_uid) + datasets = cls.get_datasets_uids(training_exp_uid) + training_exp_dict["datasets"] = datasets + training_exp = cls(**training_exp_dict) + training_exp.write() + return training_exp + + @classmethod + def __local_get(cls, training_exp_uid: Union[str, int]) -> "TrainingExp": + """Retrieves and creates a Dataset instance from the comms instance. + If the dataset is present in the user's machine then it retrieves it from there. 
+ + Args: + dset_uid (str): server UID of the dataset + + Returns: + Dataset: Specified Dataset Instance + """ + logging.debug(f"Retrieving training_exp {training_exp_uid} locally") + training_exp_dict = cls.__get_local_dict(training_exp_uid) + training_exp = cls(**training_exp_dict) + return training_exp + + @classmethod + def __get_local_dict(cls, training_exp_uid) -> dict: + """Retrieves a local training_exp information + + Args: + training_exp_uid (str): uid of the local training_exp + + Returns: + dict: information of the training_exp + """ + logging.info(f"Retrieving training_exp {training_exp_uid} from local storage") + storage = storage_path(config.training_exps_storage) + training_exp_storage = os.path.join(storage, str(training_exp_uid)) + training_exp_file = os.path.join( + training_exp_storage, config.training_exps_filename + ) + if not os.path.exists(training_exp_file): + raise InvalidArgumentError( + "No training_exp with the given uid could be found" + ) + with open(training_exp_file, "r") as f: + data = yaml.safe_load(f) + + return data + + @classmethod + def get_datasets_uids(cls, training_exp_uid: int) -> List[int]: + """Retrieves the list of models associated to the training_exp + + Args: + training_exp_uid (int): UID of the training_exp. + comms (Comms): Instance of the communications interface. + + Returns: + List[int]: List of mlcube uids + """ + return config.comms.get_experiment_datasets(training_exp_uid) + + def todict(self) -> dict: + """Dictionary representation of the training_exp instance + + Returns: + dict: Dictionary containing training_exp information + """ + return self.extended_dict() + + def write(self) -> str: + """Writes the training_exp into disk + + Args: + filename (str, optional): name of the file. Defaults to config.training_exps_filename. 
+ + Returns: + str: path to the created training_exp file + """ + data = self.todict() + training_exp_file = os.path.join(self.path, config.training_exps_filename) + if not os.path.exists(training_exp_file): + os.makedirs(self.path, exist_ok=True) + with open(training_exp_file, "w") as f: + yaml.dump(data, f) + + # write cert + os.makedirs(self.cert_path, exist_ok=True) + cert_file = os.path.join(self.cert_path, "cert.crt") + with open(cert_file, "w") as f: + f.write(self.public_key) + + # write cols + dataset_owners_emails = [""] * len( + self.datasets + ) # TODO (this will need some work) + # our medperf's user info endpoint is not public + # emails currently are not stored in medperf (auth0 only. in access tokens as well) + cols = [ + get_dataset_common_name(email, dataset_id, self.id) + for email, dataset_id in zip(dataset_owners_emails, self.datasets) + ] + with open(self.cols_path, "w") as f: + f.write("\n".join(cols)) + + return training_exp_file + + def upload(self): + """Uploads a training_exp to the server + + Args: + comms (Comms): communications entity to submit through + """ + if self.for_test: + raise InvalidArgumentError("Cannot upload test training_exps.") + body = self.todict() + updated_body = config.comms.upload_training_exp(body) + updated_body["datasets"] = body["datasets"] + return updated_body + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Description": self.description, + "Documentation": self.docs_url, + "Created At": self.created_at, + "FL MLCube": int(self.fl_mlcube), + "Associated Datasets": ",".join(map(str, self.datasets)), + "State": self.state, + "Registered": self.is_registered, + "Approval Status": self.approval_status, + } diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index 669d7dfd9..09df143d1 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -169,14 +169,14 @@ def 
test_cube_run_are_called_properly(mocker, setup): exp_model_call = call( task="infer", - output_logs=exp_model_logs_path, + output_logs_file=exp_model_logs_path, timeout=config.infer_timeout, data_path=INPUT_DATASET.data_path, output_path=exp_preds_path, ) exp_eval_call = call( task="evaluate", - output_logs=exp_metrics_logs_path, + output_logs_file=exp_metrics_logs_path, timeout=config.evaluate_timeout, predictions=exp_preds_path, labels=INPUT_DATASET.labels_path, diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index c054d64c1..fc5076823 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -21,6 +21,9 @@ from pexpect.exceptions import TIMEOUT from git import Repo, GitCommandError import medperf.config as config +from medperf.cryptography.participant import generate_csr +from medperf.cryptography.io import get_csr_hash, write_key +from medperf.cryptography.utils import cert_to_str from medperf.exceptions import ExecutionError, MedperfException, InvalidEntityError @@ -512,3 +515,48 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.proc.wait() # Return False to propagate exceptions, if any return False +def get_dataset_common_name(email, dataset_id, exp_id): + return f"{email}_d{dataset_id}_e{exp_id}".lower() + + +def generate_data_csr(email, data_uid, training_exp_id): + common_name = get_dataset_common_name(email, data_uid, training_exp_id) + private_key, csr = generate_csr(common_name, server=False) + + # store private key + target_folder = os.path.join( + config.training_exps_storage, + str(training_exp_id), + config.data_cert_folder, + str(data_uid), + ) + target_folder = storage_path(target_folder) + os.makedirs(target_folder, exist_ok=True) + target_path = os.path.join(target_folder, "key.key") + write_key(private_key, target_path) + + csr_hash = get_csr_hash(csr) + csr_str = cert_to_str(csr) + return csr_str, csr_hash + + +def generate_agg_csr(training_exp_id, agg_address, agg_id): + common_name = f"{agg_address}".lower() + 
private_key, csr = generate_csr(common_name, server=True) + + # store private key + target_folder = os.path.join( + config.training_exps_storage, + str(training_exp_id), + config.agg_cert_folder, + str(agg_id), + ) + target_folder = storage_path(target_folder) + os.makedirs(target_folder, exist_ok=True) + target_path = os.path.join(target_folder, "key.key") + write_key(private_key, target_path) + + csr_hash = get_csr_hash(csr) + csr_str = cert_to_str(csr) + + return csr_str, csr_hash diff --git a/examples/fl/mlcube/mlcube-cpu.yaml b/examples/fl/mlcube/mlcube-cpu.yaml new file mode 100644 index 000000000..a32d885cb --- /dev/null +++ b/examples/fl/mlcube/mlcube-cpu.yaml @@ -0,0 +1,40 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: hasan7/fltest:0.0.0-cpu + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile-CPU" + +tasks: + train: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + parameters_file: parameters.yaml + network_config: network.yaml + outputs: + output_logs: logs/ + start_aggregator: + parameters: + inputs: + input_weights: additional_files/init_weights + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + parameters_file: parameters.yaml + network_config: network.yaml + collaborators: cols.yaml + outputs: + output_logs: logs/ + output_weights: final_weights/ diff --git a/examples/fl/mlcube/mlcube-gpu.yaml b/examples/fl/mlcube/mlcube-gpu.yaml new file mode 100644 index 000000000..766c5b68c --- /dev/null +++ b/examples/fl/mlcube/mlcube-gpu.yaml @@ -0,0 +1,41 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 1 + +docker: + # Image name + image: hasan7/fltest:0.0.0-gpu + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile-GPU" + gpu_args: --gpus all + +tasks: + train: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + parameters_file: parameters.yaml + network_config: network.yaml + outputs: + output_logs: logs/ + start_aggregator: + parameters: + inputs: + input_weights: additional_files/init_weights + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + parameters_file: parameters.yaml + network_config: network.yaml + collaborators: cols.yaml + outputs: + output_logs: logs/ + output_weights: final_weights/ diff --git a/examples/fl/mlcube/workspace/network.yaml b/examples/fl/mlcube/workspace/network.yaml new file mode 100644 index 000000000..31a4c1466 --- /dev/null +++ b/examples/fl/mlcube/workspace/network.yaml @@ -0,0 +1,5 @@ +agg_addr: 104.197.235.200 +agg_port: 50273 + +address: 104.197.235.200 +port: 50273 \ No newline at end of file diff --git a/examples/fl/mlcube/workspace/parameters-cpu.yaml b/examples/fl/mlcube/workspace/parameters-cpu.yaml new file mode 100644 index 000000000..72d04d48f --- /dev/null +++ b/examples/fl/mlcube/workspace/parameters-cpu.yaml @@ -0,0 +1,137 @@ +plan: + aggregator: + settings: + best_state_path: save/fets_seg_test_best.pbuf + db_store_rounds: 2 + init_state_path: save/fets_seg_test_init.pbuf + last_state_path: save/fets_seg_test_last.pbuf + rounds_to_train: 2 + write_logs: true + template: openfl.component.Aggregator + assigner: + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner + collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator + compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline + data_loader: + settings: + feature_shape: + - 32 + - 32 + - 32 + template: 
openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper + network: + settings: + agg_addr: any + agg_port: any + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network + task_runner: + settings: + device: cpu + gandlf_config: + batch_size: 1 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + normalize: null + enable_padding: false + in_memory: true + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: dc + medcam_enabled: false + metrics: + - dice + model: + amp: true + architecture: unet + base_filters: 32 + batch_norm: false + class_list: + - 0 + - 1 + dimension: 3 + final_layer: sigmoid + ignore_label_validation: null + norm_type: instance + num_channels: 1 + nested_training: + testing: -5 + validation: -5 + num_epochs: 1 + optimizer: + type: adam + output_dir: . + parallel_compute_command: '' + patch_sampler: uniform + patch_size: + - 32 + - 32 + - 32 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 1 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.14 + minimum: 0.0.13 + weighted_loss: true + train_csv: seg_test_train.csv + val_csv: seg_test_val.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner + tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_dice + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_dice + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_dice diff --git a/examples/fl/mlcube/workspace/parameters-gpu.yaml 
b/examples/fl/mlcube/workspace/parameters-gpu.yaml new file mode 100644 index 000000000..b29186193 --- /dev/null +++ b/examples/fl/mlcube/workspace/parameters-gpu.yaml @@ -0,0 +1,137 @@ +plan: + aggregator: + settings: + best_state_path: save/fets_seg_test_best.pbuf + db_store_rounds: 2 + init_state_path: save/fets_seg_test_init.pbuf + last_state_path: save/fets_seg_test_last.pbuf + rounds_to_train: 2 + write_logs: true + template: openfl.component.Aggregator + assigner: + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner + collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator + compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline + data_loader: + settings: + feature_shape: + - 32 + - 32 + - 32 + template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper + network: + settings: + agg_addr: any + agg_port: any + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network + task_runner: + settings: + device: cuda + gandlf_config: + batch_size: 1 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + normalize: null + enable_padding: false + in_memory: true + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: dc + medcam_enabled: false + metrics: + - dice + model: + amp: true + architecture: unet + base_filters: 32 + batch_norm: false + class_list: + - 0 + - 1 + dimension: 3 + final_layer: sigmoid + ignore_label_validation: null + norm_type: instance + num_channels: 1 + nested_training: + testing: -5 + validation: -5 + num_epochs: 1 + optimizer: + type: adam + output_dir: . 
+ parallel_compute_command: '' + patch_sampler: uniform + patch_size: + - 32 + - 32 + - 32 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 1 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.14 + minimum: 0.0.13 + weighted_loss: true + train_csv: seg_test_train.csv + val_csv: seg_test_val.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner + tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_dice + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_dice + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_dice diff --git a/examples/fl/mlcube/workspace/parameters-miccai.yaml b/examples/fl/mlcube/workspace/parameters-miccai.yaml new file mode 100644 index 000000000..a9ec969e5 --- /dev/null +++ b/examples/fl/mlcube/workspace/parameters-miccai.yaml @@ -0,0 +1,166 @@ +plan: + aggregator: + settings: + best_state_path: save/classification_best.pbuf + db_store_rounds: 2 + init_state_path: save/classification_init.pbuf + last_state_path: save/classification_last.pbuf + rounds_to_train: 3 + write_logs: true + template: openfl.component.Aggregator + assigner: + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner + collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator + compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline + data_loader: + settings: + feature_shape: + - 128 + - 128 + template: 
openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper + network: + settings: + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network + task_runner: + settings: + device: cpu + gandlf_config: + memory_save_mode: false # + batch_size: 16 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + resize: + - 128 + - 128 + enable_padding: false + grid_aggregator_overlap: crop + in_memory: false + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: cel + medcam_enabled: false + metrics: + accuracy: + average: weighted + mdmc_average: samplewise + multi_class: true + subset_accuracy: false + threshold: 0.5 + balanced_accuracy: None + classification_accuracy: None + f1: + average: weighted + f1: + average: weighted + mdmc_average: samplewise + multi_class: true + threshold: 0.5 + modality: rad + model: + amp: false + architecture: resnet18 + base_filters: 32 + batch_norm: true + class_list: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + dimension: 2 + final_layer: sigmoid + ignore_label_validation: None + n_channels: 3 + norm_type: batch + num_channels: 3 + save_at_every_epoch: false + type: torch + nested_training: + testing: 1 + validation: -5 + num_epochs: 5 + opt: adam + optimizer: + type: adam + output_dir: . 
+ parallel_compute_command: "" + patch_sampler: uniform + patch_size: + - 128 + - 128 + - 1 + patience: 5 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 5 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_masks: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + step_size: 0.0002 + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.14 + minimum: 0.0.14 + weighted_loss: true + train_csv: train_path_full.csv + val_csv: val_path_full.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner + tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_accuracy + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_accuracy + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_accuracy diff --git a/examples/fl/project/Dockerfile-CPU b/examples/fl/project/Dockerfile-CPU new file mode 100644 index 000000000..270d72d8c --- /dev/null +++ b/examples/fl/project/Dockerfile-CPU @@ -0,0 +1,31 @@ +FROM local/openfl:local + +ENV GANDLF_VERSION 60c9d28aa5e1b951e75ed5646ac20d5790fe4317 +ENV FL_WORKSPACE /mlcube_project/fl_workspace +ENV LANG C.UTF-8 + +# install software requirements needed by GaNDLF +RUN apt-get update && apt-get upgrade -y && apt-get install -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +# install GaNDLF (cpu) +RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu && \ + pip install --no-cache-dir openvino-dev==2023.0.1 && \ + git clone https://github.com/mlcommons/GaNDLF.git && \ + cd GaNDLF && git checkout $GANDLF_VERSION && \ + pip install --no-cache-dir -e . 
+
+
+# install workspace requirements
+COPY ./fl_workspace/requirements.txt $FL_WORKSPACE/requirements.txt
+RUN pip install --no-cache-dir -r $FL_WORKSPACE/requirements.txt
+
+# START hotfix: patch gandlf runner
+RUN rm /openfl/openfl/federated/task/runner_gandlf.py
+COPY ./hotfix.py /openfl/openfl/federated/task/runner_gandlf.py
+RUN pip install --no-cache-dir -e /openfl/
+# END hotfix
+
+# Copy mlcube workspace
+COPY . /mlcube_project
+
+ENTRYPOINT ["python", "/mlcube_project/mlcube.py"]
\ No newline at end of file
diff --git a/examples/fl/project/Dockerfile-GPU b/examples/fl/project/Dockerfile-GPU
new file mode 100644
index 000000000..ecdf14622
--- /dev/null
+++ b/examples/fl/project/Dockerfile-GPU
@@ -0,0 +1,25 @@
+FROM local/openfl:local
+
+ENV GANDLF_VERSION 60c9d28aa5e1b951e75ed5646ac20d5790fe4317
+ENV FL_WORKSPACE /mlcube_project/fl_workspace
+ENV LANG C.UTF-8
+
+# install software requirements needed by GaNDLF
+RUN apt-get update && apt-get upgrade -y && apt-get install -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++
+
+# install GaNDLF (gpu)
+RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 && \
+    pip install openvino-dev==2023.0.1 && \
+    git clone https://github.com/mlcommons/GaNDLF.git && \
+    cd GaNDLF && git checkout $GANDLF_VERSION && \
+    pip install -e .
+
+
+# install workspace requirements
+COPY ./fl_workspace/requirements.txt $FL_WORKSPACE/requirements.txt
+RUN pip install --no-cache-dir -r $FL_WORKSPACE/requirements.txt
+
+# Copy mlcube workspace
+COPY .
/mlcube_project
+
+ENTRYPOINT ["python", "/mlcube_project/mlcube.py"]
\ No newline at end of file
diff --git a/examples/fl/project/README.md b/examples/fl/project/README.md
new file mode 100644
index 000000000..ae4653520
--- /dev/null
+++ b/examples/fl/project/README.md
@@ -0,0 +1,43 @@
+# How to configure container build
+
+- List your pip requirements in `openfl_workspace/requirements.txt`
+- Modify container base image and/or how GaNDLF is installed in `dockerfile` to have GPU support.
+- Modify the GaNDLF hash (or simply copy your customized GaNDLF repo code) in `dockerfile` to use a custom GaNDLF version.
+
+Note: the plan to be attached to the container should be GaNDLF+OpenFL plan (I guess).
+
+# How to build
+
+- Build the openfl base image:
+
+```bash
+git clone https://github.com/securefederatedai/openfl.git
+cd openfl
+git checkout 11db12785c1a6a2d3c75656b38108443f88919e8
+docker build -t local/openfl:local -f openfl-docker/Dockerfile.base .
+cd ..
+rm -rf openfl
+```
+
+- Build the MLCube
+
+```bash
+cd ../mlcube
+mlcube configure -Pdocker.build_strategy=always
+```
+
+# Expected assets to be attached
+
+(outdated)
+
+- cert folders: certificates of the aggregator/collaborator and the CA's public key
+- collaborator list, FL plan, and the init weights for the aggregator
+- training data for the collaborator
+
+# NOTE
+
+for local experiments, internal IP address or localhost will not work. Use internal fqdn.
+
+# For later
+
+- To use a plan that doesn't depend on GaNDLF, maybe `openfl_workspace/src` should be prepopulated with the necessary code.
diff --git a/examples/fl/project/aggregator.py b/examples/fl/project/aggregator.py new file mode 100644 index 000000000..61d583a12 --- /dev/null +++ b/examples/fl/project/aggregator.py @@ -0,0 +1,50 @@ +from utils import ( + get_aggregator_fqdn, + prepare_node_cert, + prepare_ca_cert, + prepare_plan, + prepare_cols_list, + prepare_init_weights, + get_weights_path, + WORKSPACE, +) + +import os +from subprocess import check_call +from distutils.dir_util import copy_tree + + +def start_aggregator( + input_weights, + parameters_file, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + network_config, + collaborators, +): + prepare_plan(parameters_file, network_config) + prepare_cols_list(collaborators) + prepare_init_weights(input_weights) + fqdn = get_aggregator_fqdn() + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}") + prepare_ca_cert(ca_cert_folder) + + check_call(["fx", "aggregator", "start"], cwd=WORKSPACE) + + # TODO: check how to copy logs during runtime. + # perhaps investigate overriding plan entries? + + # NOTE: logs and weights are copied, even if target folders are not empty + copy_tree(os.path.join(WORKSPACE, "logs"), output_logs) + + # NOTE: conversion fails since openfl needs sample data... 
+ # weights_paths = get_weights_path() + # out_best = os.path.join(output_weights, "best") + # out_last = os.path.join(output_weights, "last") + # check_call( + # ["fx", "model", "save", "-i", weights_paths["best"], "-o", out_best], + # cwd=WORKSPACE, + # ) + copy_tree(os.path.join(WORKSPACE, "save"), output_weights) diff --git a/examples/fl/project/collaborator.py b/examples/fl/project/collaborator.py new file mode 100644 index 000000000..d49a89483 --- /dev/null +++ b/examples/fl/project/collaborator.py @@ -0,0 +1,29 @@ +from utils import ( + get_collaborator_cn, + prepare_node_cert, + prepare_ca_cert, + prepare_plan, + prepare_data, + WORKSPACE, +) +import os +from subprocess import check_call + + +def start_collaborator( + data_path, + labels_path, + parameters_file, + node_cert_folder, + ca_cert_folder, + network_config, + output_logs, # TODO: Is it needed? +): + prepare_plan(parameters_file, network_config) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}") + prepare_ca_cert(ca_cert_folder) + prepare_data(data_path, labels_path, cn) + + # set log files + check_call(["fx", "collaborator", "start", "-n", cn], cwd=WORKSPACE) diff --git a/examples/fl/project/fl_workspace/.workspace b/examples/fl/project/fl_workspace/.workspace new file mode 100644 index 000000000..3c2c5d08b --- /dev/null +++ b/examples/fl/project/fl_workspace/.workspace @@ -0,0 +1,2 @@ +current_plan_name: default + diff --git a/examples/fl/project/fl_workspace/plan/defaults b/examples/fl/project/fl_workspace/plan/defaults new file mode 100644 index 000000000..fb82f9c5b --- /dev/null +++ b/examples/fl/project/fl_workspace/plan/defaults @@ -0,0 +1,2 @@ +../../workspace/plan/defaults + diff --git a/examples/fl/project/fl_workspace/requirements.txt b/examples/fl/project/fl_workspace/requirements.txt new file mode 100644 index 000000000..709016a50 --- /dev/null +++ b/examples/fl/project/fl_workspace/requirements.txt @@ -0,0 +1 @@ +onnx==1.13.0 diff --git 
a/examples/fl/project/hotfix.py b/examples/fl/project/hotfix.py new file mode 100644 index 000000000..bbcc692a6 --- /dev/null +++ b/examples/fl/project/hotfix.py @@ -0,0 +1,667 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""GaNDLFTaskRunner module.""" + +from copy import deepcopy + +import numpy as np +import os +import torch as pt +from typing import Union +import yaml + +from openfl.utilities.split import split_tensor_dict_for_holdouts +from openfl.utilities import TensorKey + +from .runner import TaskRunner + +from GANDLF.compute.generic import create_pytorch_objects +from GANDLF.compute.training_loop import train_network +from GANDLF.compute.forward_pass import validate_network + + +class GaNDLFTaskRunner(TaskRunner): + """GaNDLF Model class for Federated Learning.""" + + def __init__( + self, + gandlf_config: Union[str, dict] = None, + device: str = None, + **kwargs + ): + """Initialize. + Args: + device (string): Compute device (default="cpu") + **kwargs: Additional parameters to pass to the functions + """ + super().__init__(**kwargs) + + # allow pass-through of a gandlf config as a file or a dict + + train_csv = self.data_loader.train_csv + val_csv = self.data_loader.val_csv + + if isinstance(gandlf_config, str) and os.path.exists(gandlf_config): + gandlf_config = yaml.safe_load(open(gandlf_config, "r")) + + ( + model, + optimizer, + train_loader, + val_loader, + scheduler, + params, + ) = create_pytorch_objects( + gandlf_config, train_csv=train_csv, val_csv=val_csv, device=device + ) + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.params = params + self.device = device + + # pass the actual dataloaders to the wrapper loader + self.data_loader.set_dataloaders(train_loader, val_loader) + + self.training_round_completed = False + + self.required_tensorkeys_for_function = {} + + # FIXME: why isn't this initial call in runner_pt? 
+ self.initialize_tensorkeys_for_functions(with_opt_vars=False) + + # overwrite attribute to account for one optimizer param (in every + # child model that does not overwrite get and set tensordict) that is + # not a numpy array + self.tensor_dict_split_fn_kwargs.update({ + 'holdout_tensor_names': ['__opt_state_needed'] + }) + + def rebuild_model(self, round_num, input_tensor_dict, validation=False): + """ + Parse tensor names and update weights of model. Handles the optimizer treatment. + Returns: + None + """ + + if self.opt_treatment == 'RESET': + self.reset_opt_vars() + self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) + elif (self.training_round_completed + and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation): + self.set_tensor_dict(input_tensor_dict, with_opt_vars=True) + else: + self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) + + def validate(self, col_name, round_num, input_tensor_dict, + use_tqdm=False, **kwargs): + """Validate. + Run validation of the model on the local data. 
+ Args: + col_name: Name of the collaborator + round_num: What round is it + input_tensor_dict: Required input tensors (for model) + use_tqdm (bool): Use tqdm to print a progress bar (Default=True) + kwargs: Key word arguments passed to GaNDLF main_run + Returns: + global_output_dict: Tensors to send back to the aggregator + local_output_dict: Tensors to maintain in the local TensorDB + """ + self.rebuild_model(round_num, input_tensor_dict, validation=True) + self.model.eval() + + epoch_valid_loss, epoch_valid_metric = validate_network(self.model, + self.data_loader.val_dataloader, + self.scheduler, + self.params, + round_num, + mode="validation") + + self.logger.info(epoch_valid_loss) + self.logger.info(epoch_valid_metric) + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + + output_tensor_dict = {} + valid_loss_tensor_key = TensorKey('valid_loss', origin, round_num, True, tags) + output_tensor_dict[valid_loss_tensor_key] = np.array(epoch_valid_loss) + for k, v in epoch_valid_metric.items(): + if isinstance(v, str) and "_" in v: + continue + tensor_key = TensorKey(f'valid_{k}', origin, round_num, True, tags) + output_tensor_dict[tensor_key] = np.array(v) + + # Empty list represents metrics that should only be stored locally + return output_tensor_dict, {} + + def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs): + """Train batches. + Train the model on the requested number of batches. 
+ Args: + col_name : Name of the collaborator + round_num : What round is it + input_tensor_dict : Required input tensors (for model) + use_tqdm (bool) : Use tqdm to print a progress bar (Default=True) + epochs : The number of epochs to train + crossfold_test : Whether or not to use cross fold trainval/test + to evaluate the quality of the model under fine tuning + (this uses a separate prameter to pass in the data and + config used) + crossfold_test_data_csv : Data csv used to define data used in crossfold test. + This csv does not itself define the folds, just + defines the total data to be used. + crossfold_val_n : number of folds to use for the train,val level + of the nested crossfold. + corssfold_test_n : number of folds to use for the trainval,test level + of the nested crossfold. + kwargs : Key word arguments passed to GaNDLF main_run + Returns: + global_output_dict : Tensors to send back to the aggregator + local_output_dict : Tensors to maintain in the local TensorDB + """ + self.rebuild_model(round_num, input_tensor_dict) + # set to "training" mode + self.model.train() + for epoch in range(epochs): + self.logger.info(f'Run {epoch} epoch of {round_num} round') + # FIXME: do we want to capture these in an array + # rather than simply taking the last value? + epoch_train_loss, epoch_train_metric = train_network(self.model, + self.data_loader.train_dataloader, + self.optimizer, + self.params) + + # output model tensors (Doesn't include TensorKey) + tensor_dict = self.get_tensor_dict(with_opt_vars=True) + + metric_dict = {'loss': epoch_train_loss} + for k, v in epoch_train_metric.items(): + if isinstance(v, str) and "_" in v: + continue + metric_dict[f'train_{k}'] = v + + # Return global_tensor_dict, local_tensor_dict + # is this even pt-specific really? 
+ global_tensor_dict, local_tensor_dict = create_tensorkey_dicts( + tensor_dict, + metric_dict, + col_name, + round_num, + self.logger, + self.tensor_dict_split_fn_kwargs, + ) + + # Update the required tensors if they need to be pulled from the + # aggregator + # TODO this logic can break if different collaborators have different + # roles between rounds. + # For example, if a collaborator only performs validation in the first + # round but training in the second, it has no way of knowing the + # optimizer state tensor names to request from the aggregator because + # these are only created after training occurs. A work around could + # involve doing a single epoch of training on random data to get the + # optimizer names, and then throwing away the model. + if self.opt_treatment == 'CONTINUE_GLOBAL': + self.initialize_tensorkeys_for_functions(with_opt_vars=True) + + # This will signal that the optimizer values are now present, + # and can be loaded when the model is rebuilt + self.training_round_completed = True + + # Return global_tensor_dict, local_tensor_dict + return global_tensor_dict, local_tensor_dict + + def get_tensor_dict(self, with_opt_vars=False): + """Return the tensor dictionary. + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + Returns: + dict: Tensor dictionary {**dict, **optimizer_dict} + """ + # Gets information regarding tensor model layers and optimizer state. + # FIXME: self.parameters() instead? Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names which is good + # FIXME: do both and sanity check each time? + + state = to_cpu_numpy(self.model.state_dict()) + + if with_opt_vars: + opt_state = _get_optimizer_state(self.optimizer) + state = {**state, **opt_state} + + return state + + def _get_weights_names(self, with_opt_vars=False): + # Gets information regarding tensor model layers and optimizer state. 
+ # FIXME: self.parameters() instead? Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names which is good + # FIXME: do both and sanity check each time? + + state = self.model.state_dict().keys() + + if with_opt_vars: + opt_state = _get_optimizer_state(self.model.optimizer) + state += opt_state.keys() + + return state + + def set_tensor_dict(self, tensor_dict, with_opt_vars=False): + """Set the tensor dictionary. + Args: + tensor_dict: The tensor dictionary + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + """ + set_pt_model_from_tensor_dict(self.model, tensor_dict, self.device, with_opt_vars) + + def get_optimizer(self): + """Get the optimizer of this instance.""" + return self.optimizer + + def get_required_tensorkeys_for_function(self, func_name, **kwargs): + """ + Get the required tensors for specified function that could be called \ + as part of a task. By default, this is just all of the layers and \ + optimizer of the model. + Args: + func_name + Returns: + list : [TensorKey] + """ + if func_name == 'validate': + local_model = 'apply=' + str(kwargs['apply']) + return self.required_tensorkeys_for_function[func_name][local_model] + else: + return self.required_tensorkeys_for_function[func_name] + + def initialize_tensorkeys_for_functions(self, with_opt_vars=False): + """Set the required tensors for all publicly accessible task methods. + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function. + Args: + None + Returns: + None + """ + # TODO there should be a way to programmatically iterate through + # all of the methods in the class and declare the tensors. 
+ # For now this is done manually + + output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + self.logger, output_model_dict, + **self.tensor_dict_split_fn_kwargs + ) + if not with_opt_vars: + global_model_dict_val = global_model_dict + local_model_dict_val = local_model_dict + else: + output_model_dict = self.get_tensor_dict(with_opt_vars=False) + global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( + self.logger, + output_model_dict, + **self.tensor_dict_split_fn_kwargs + ) + + self.required_tensorkeys_for_function['train'] = [ + TensorKey( + tensor_name, 'GLOBAL', 0, False, ('model',) + ) for tensor_name in global_model_dict + ] + self.required_tensorkeys_for_function['train'] += [ + TensorKey( + tensor_name, 'LOCAL', 0, False, ('model',) + ) for tensor_name in local_model_dict + ] + + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + self.required_tensorkeys_for_function['validate'] = {} + # TODO This is not stateless. The optimizer will not be + self.required_tensorkeys_for_function['validate']['apply=local'] = [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + for tensor_name in { + **global_model_dict_val, + **local_model_dict_val + }] + self.required_tensorkeys_for_function['validate']['apply=global'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + for tensor_name in global_model_dict_val + ] + self.required_tensorkeys_for_function['validate']['apply=global'] += [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + for tensor_name in local_model_dict_val + ] + + def load_native(self, filepath, model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', **kwargs): + """ + Load model and optimizer states from a pickled file specified by \ + filepath. model_/optimizer_state_dict args can be specified if needed. 
\ + Uses pt.load(). + Args: + filepath (string) : Path to pickle file created + by pt.save(). + model_state_dict_key (string) : key for model state dict + in pickled file. + optimizer_state_dict_key (string) : key for optimizer state dict + in picked file. + kwargs : unused + Returns: + None + """ + pickle_dict = pt.load(filepath) + self.model.load_state_dict(pickle_dict[model_state_dict_key]) + self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key]) + + def save_native(self, filepath, model_state_dict_key='model_state_dict', + optimizer_state_dict_key='optimizer_state_dict', **kwargs): + """ + Save model and optimizer states in a picked file specified by the \ + filepath. model_/optimizer_state_dicts are stored in the keys provided. \ + Uses pt.save(). + Args: + filepath (string) : Path to pickle file to be + created by pt.save(). + model_state_dict_key (string) : key for model state dict + in pickled file. + optimizer_state_dict_key (string) : key for optimizer state + dict in picked file. + kwargs : unused + Returns: + None + """ + pickle_dict = { + model_state_dict_key: self.model.state_dict(), + optimizer_state_dict_key: self.optimizer.state_dict() + } + pt.save(pickle_dict, filepath) + + def reset_opt_vars(self): + """ + Reset optimizer variables. 
+ Resets the optimizer variables + """ + pass + + +def create_tensorkey_dicts(tensor_dict, + metric_dict, + col_name, + round_num, + logger, + tensor_dict_split_fn_kwargs): + origin = col_name + tags = ('trained',) + output_metric_dict = {} + for k, v in metric_dict.items(): + tk = TensorKey(k, origin, round_num, True, ('metric',)) + output_metric_dict[tk] = np.array(v) + + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + logger, tensor_dict, **tensor_dict_split_fn_kwargs + ) + + # Create global tensorkeys + global_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num, False, tags): + nparray for tensor_name, nparray in global_model_dict.items() + } + # Create tensorkeys that should stay local + local_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num, False, tags): + nparray for tensor_name, nparray in local_model_dict.items() + } + # The train/validate aggregated function of the next round will look + # for the updated model parameters. + # This ensures they will be resolved locally + next_local_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num + 1, False, ('model',)): nparray + for tensor_name, nparray in local_model_dict.items()} + + global_tensor_dict = { + **output_metric_dict, + **global_tensorkey_model_dict + } + local_tensor_dict = { + **local_tensorkey_model_dict, + **next_local_tensorkey_model_dict + } + + return global_tensor_dict, local_tensor_dict + + +def set_pt_model_from_tensor_dict(model, tensor_dict, device, with_opt_vars=False): + """Set the tensor dictionary. + Args: + model: the pytorch nn.module object + tensor_dict: The tensor dictionary + device: the device where the tensor values need to be sent + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + """ + # Sets tensors for model layers and optimizer state. + # FIXME: model.parameters() instead? 
Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names, which is good + # FIXME: do both and sanity check each time? + + new_state = {} + # Grabbing keys from model's state_dict helps to confirm we have + # everything + for k in model.state_dict(): + new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device) + + # set model state + model.load_state_dict(new_state) + + if with_opt_vars: + # see if there is state to restore first + if tensor_dict.pop('__opt_state_needed') == 'true': + _set_optimizer_state(model.get_optimizer(), device, tensor_dict) + + # sanity check that we did not record any state that was not used + assert len(tensor_dict) == 0 + + +def _derive_opt_state_dict(opt_state_dict): + """Separate optimizer tensors from the tensor dictionary. + Flattens the optimizer state dict so as to have key, value pairs with + values as numpy arrays. + The keys have sufficient info to restore opt_state_dict using + expand_derived_opt_state_dict. + Args: + opt_state_dict: The optimizer state dictionary + """ + derived_opt_state_dict = {} + + # Determine if state is needed for this optimizer. + if len(opt_state_dict['state']) == 0: + derived_opt_state_dict['__opt_state_needed'] = 'false' + return derived_opt_state_dict + + derived_opt_state_dict['__opt_state_needed'] = 'true' + + # Using one example state key, we collect keys for the corresponding + # dictionary value. + example_state_key = opt_state_dict['param_groups'][0]['params'][0] + example_state_subkeys = set( + opt_state_dict['state'][example_state_key].keys() + ) + + # We assume that the state collected for all params in all param groups is + # the same. + # We also assume that whether or not the associated values to these state + # subkeys is a tensor depends only on the subkey. + # Using assert statements to break the routine if these assumptions are + # incorrect. 
+ for state_key in opt_state_dict['state'].keys(): + assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) + for state_subkey in example_state_subkeys: + assert (isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor) + == isinstance( + opt_state_dict['state'][state_key][state_subkey], + pt.Tensor)) + + state_subkeys = list(opt_state_dict['state'][example_state_key].keys()) + + # Tags will record whether the value associated to the subkey is a + # tensor or not. + state_subkey_tags = [] + for state_subkey in state_subkeys: + if isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor + ): + state_subkey_tags.append('istensor') + else: + state_subkey_tags.append('') + state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags)) + + # Forming the flattened dict, using a concatenation of group index, + # subindex, tag, and subkey inserted into the flattened dict key - + # needed for reconstruction. + nb_params_per_group = [] + for group_idx, group in enumerate(opt_state_dict['param_groups']): + for idx, param_id in enumerate(group['params']): + for subkey, tag in state_subkeys_and_tags: + if tag == 'istensor': + new_v = opt_state_dict['state'][param_id][ + subkey].cpu().numpy() + else: + new_v = np.array( + [opt_state_dict['state'][param_id][subkey]] + ) + derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v + nb_params_per_group.append(idx + 1) + # group lengths are also helpful for reconstructing + # original opt_state_dict structure + derived_opt_state_dict['__opt_group_lengths'] = np.array( + nb_params_per_group + ) + + return derived_opt_state_dict + + +def expand_derived_opt_state_dict(derived_opt_state_dict, device): + """Expand the optimizer state dictionary. + Takes a derived opt_state_dict and creates an opt_state_dict suitable as + input for load_state_dict for restoring optimizer state. 
+ Reconstructing state_subkeys_and_tags using the example key + prefix, "__opt_state_0_0_", certain to be present. + Args: + derived_opt_state_dict: Optimizer state dictionary + Returns: + dict: Optimizer state dictionary + """ + state_subkeys_and_tags = [] + for key in derived_opt_state_dict: + if key.startswith('__opt_state_0_0_'): + stripped_key = key[16:] + if stripped_key.startswith('istensor_'): + this_tag = 'istensor' + subkey = stripped_key[9:] + else: + this_tag = '' + subkey = stripped_key[1:] + state_subkeys_and_tags.append((subkey, this_tag)) + + opt_state_dict = {'param_groups': [], 'state': {}} + nb_params_per_group = list( + derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) + ) + + # Construct the expanded dict. + for group_idx, nb_params in enumerate(nb_params_per_group): + these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)] + opt_state_dict['param_groups'].append({'params': these_group_ids}) + for this_id in these_group_ids: + opt_state_dict['state'][this_id] = {} + for subkey, tag in state_subkeys_and_tags: + flat_key = f'__opt_state_{this_id}_{tag}_{subkey}' + if tag == 'istensor': + new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key)) + else: + # Here (for currrently supported optimizers) the subkey + # should be 'step' and the length of array should be one. + assert subkey == 'step' + assert len(derived_opt_state_dict[flat_key]) == 1 + new_v = int(derived_opt_state_dict.pop(flat_key)) + opt_state_dict['state'][this_id][subkey] = new_v + + # sanity check that we did not miss any optimizer state + assert len(derived_opt_state_dict) == 0 + + return opt_state_dict + + +def _get_optimizer_state(optimizer): + """Return the optimizer state. 
+ Args: + optimizer + """ + opt_state_dict = deepcopy(optimizer.state_dict()) + + # Optimizer state might not have some parts representing frozen parameters + # So we do not synchronize them + param_keys_with_state = set(opt_state_dict['state'].keys()) + for group in opt_state_dict['param_groups']: + local_param_set = set(group['params']) + params_to_sync = local_param_set & param_keys_with_state + group['params'] = sorted(params_to_sync) + + derived_opt_state_dict = _derive_opt_state_dict(opt_state_dict) + + return derived_opt_state_dict + + +def _set_optimizer_state(optimizer, device, derived_opt_state_dict): + """Set the optimizer state. + Args: + optimizer: + device: + derived_opt_state_dict: + """ + temp_state_dict = expand_derived_opt_state_dict( + derived_opt_state_dict, device) + + # FIXME: Figure out whether or not this breaks learning rate + # scheduling and the like. + # Setting default values. + # All optimizer.defaults are considered as not changing over course of + # training. + for group in temp_state_dict['param_groups']: + for k, v in optimizer.defaults.items(): + group[k] = v + + optimizer.load_state_dict(temp_state_dict) + + +def to_cpu_numpy(state): + """Send data to CPU as Numpy array. + Args: + state + """ + # deep copy so as to decouple from active model + state = deepcopy(state) + + for k, v in state.items(): + # When restoring, we currently assume all values are tensors. 
+ if not pt.is_tensor(v): + raise ValueError('We do not currently support non-tensors ' + 'coming from model.state_dict()') + # get as a numpy array, making sure is on cpu + state[k] = v.cpu().numpy() + return state diff --git a/examples/fl/project/mlcube.py b/examples/fl/project/mlcube.py new file mode 100644 index 000000000..6ee9e2de3 --- /dev/null +++ b/examples/fl/project/mlcube.py @@ -0,0 +1,34 @@ +"""MLCube handler file""" +import argparse +from collaborator import start_collaborator +from aggregator import start_aggregator + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + train = subparsers.add_parser("train") + train.add_argument("--data_path", metavar="", type=str, required=True) + train.add_argument("--labels_path", metavar="", type=str, required=True) + train.add_argument("--parameters_file", metavar="", type=str, required=True) + train.add_argument("--node_cert_folder", metavar="", type=str, required=True) + train.add_argument("--ca_cert_folder", metavar="", type=str, required=True) + train.add_argument("--network_config", metavar="", type=str, required=True) + train.add_argument("--output_logs", metavar="", type=str, required=True) + + agg = subparsers.add_parser("start_aggregator") + agg.add_argument("--input_weights", metavar="", type=str, required=True) + agg.add_argument("--parameters_file", metavar="", type=str, required=True) + agg.add_argument("--node_cert_folder", metavar="", type=str, required=True) + agg.add_argument("--ca_cert_folder", metavar="", type=str, required=True) + agg.add_argument("--output_logs", metavar="", type=str, required=True) + agg.add_argument("--output_weights", metavar="", type=str, required=True) + agg.add_argument("--network_config", metavar="", type=str, required=True) + agg.add_argument("--collaborators", metavar="", type=str, required=True) + + args = parser.parse_args() + if hasattr(args, "data_path"): + start_collaborator(**vars(args)) + else: + 
start_aggregator(**vars(args)) diff --git a/examples/fl/project/utils.py b/examples/fl/project/utils.py new file mode 100644 index 000000000..52c1eeeca --- /dev/null +++ b/examples/fl/project/utils.py @@ -0,0 +1,164 @@ +import yaml +import os +import pandas as pd + +WORKSPACE = os.environ["FL_WORKSPACE"] + + +def get_aggregator_fqdn(): + plan_path = os.path.join(WORKSPACE, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return plan["network"]["settings"]["agg_addr"].lower() + + +def get_collaborator_cn(): + # TODO: check if there is a way this can cause a collision/race condition + # TODO: from inside the file + return os.environ["COLLABORATOR_CN"] + + +def get_weights_path(): + plan_path = os.path.join(WORKSPACE, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return { + "init": plan["aggregator"]["settings"]["init_state_path"], + "best": plan["aggregator"]["settings"]["best_state_path"], + "last": plan["aggregator"]["settings"]["last_state_path"], + } + + +def prepare_plan(parameters_file, network_config): + with open(parameters_file) as f: + params = yaml.safe_load(f) + if "plan" not in params: + raise RuntimeError("Parameters file should contain a `plan` entry") + with open(network_config) as f: + network_config_dict = yaml.safe_load(f) + plan = params["plan"] + plan["network"]["settings"].update(network_config_dict) + target_plan_folder = os.path.join(WORKSPACE, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + with open(target_plan_file, "w") as f: + yaml.dump(plan, f) + + +def prepare_cols_list(collaborators_file): + with open(collaborators_file) as f: + cols = f.read().strip().split("\n") + + target_plan_folder = os.path.join(WORKSPACE, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "cols.yaml") + with open(target_plan_file, "w") as f: + yaml.dump({"collaborators": cols}, f) + + +def 
def prepare_init_weights(input_weights):
    """Symlink the single *.pbuf file in `input_weights` to the plan's
    init_state_path inside the workspace.

    Raises:
        RuntimeError: if the folder does not contain exactly one .pbuf file.
    """
    error_msg = f"{input_weights} should contain only one file: *.pbuf"

    files = os.listdir(input_weights)
    # BUG FIX: check the count before indexing — the original read files[0]
    # first, so an empty folder raised IndexError instead of RuntimeError.
    if len(files) != 1 or not files[0].endswith(".pbuf"):
        raise RuntimeError(error_msg)

    file = os.path.join(input_weights, files[0])

    target_weights_subpath = get_weights_path()["init"]
    target_weights_path = os.path.join(WORKSPACE, target_weights_subpath)
    target_weights_folder = os.path.dirname(target_weights_path)
    os.makedirs(target_weights_folder, exist_ok=True)
    os.symlink(file, target_weights_path)


def prepare_node_cert(node_cert_folder, target_cert_folder_name, target_cert_name):
    """Symlink this node's TLS cert/key pair into the workspace cert folder.

    Args:
        node_cert_folder: folder holding exactly one .crt and one .key file.
        target_cert_folder_name: subfolder name under WORKSPACE/cert.
        target_cert_name: base name for the linked .crt/.key files.

    Raises:
        RuntimeError: if the folder does not hold exactly one .crt and one .key.
    """
    error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key"

    files = os.listdir(node_cert_folder)
    file_extensions = [file.split(".")[-1] for file in files]
    if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]:
        raise RuntimeError(error_msg)

    # Validation above guarantees one .crt and one .key, in either order.
    if files[0].endswith(".crt"):
        cert_file, key_file = files
    else:
        key_file, cert_file = files

    key_file = os.path.join(node_cert_folder, key_file)
    cert_file = os.path.join(node_cert_folder, cert_file)

    target_cert_folder = os.path.join(WORKSPACE, "cert", target_cert_folder_name)
    os.makedirs(target_cert_folder, exist_ok=True)
    target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt")
    target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key")

    os.symlink(key_file, target_key_file)
    os.symlink(cert_file, target_cert_file)


def prepare_ca_cert(ca_cert_folder):
    """Symlink the CA certificate into the workspace as cert/cert_chain.crt.

    Raises:
        RuntimeError: if the folder does not contain exactly one .crt file.
    """
    error_msg = f"{ca_cert_folder} should contain only one file: *.crt"

    files = os.listdir(ca_cert_folder)
    # BUG FIX: length check before indexing (same IndexError issue as
    # prepare_init_weights).
    if len(files) != 1 or not files[0].endswith(".crt"):
        raise RuntimeError(error_msg)

    file = os.path.join(ca_cert_folder, files[0])

    target_ca_cert_folder = os.path.join(WORKSPACE, "cert")
    os.makedirs(target_ca_cert_folder, exist_ok=True)
    target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt")

    os.symlink(file, target_ca_cert_file)


def __modify_df(df):
    """Prefix GaNDLF data/label columns of `df` IN PLACE.

    GaNDLF convention: label columns may be "target", "label", or "mask";
    the subject id column is "subjectid"; data columns are "channel_0".
    Other columns (e.g. scalars) are left untouched.  # TODO
    """
    labels_columns = ["target", "label", "mask"]
    data_columns = ["channel_0"]
    subject_id_column = "subjectid"
    for column in df.columns:
        if column.lower() == subject_id_column:
            continue
        if column.lower() in labels_columns:
            prepend_str = "labels/"
        elif column.lower() in data_columns:
            prepend_str = "data/"
        else:
            continue

        df[column] = prepend_str + df[column].astype(str)


def prepare_data(data_path, labels_path, cn):
    """Wire a collaborator's data/labels into the workspace.

    Symlinks the data and labels folders under data/<cn>/, rewrites
    train.csv/valid.csv with workspace-relative paths, and registers the
    collaborator in plan/data.yaml.
    """
    target_data_folder = os.path.join(WORKSPACE, "data", cn)
    os.makedirs(target_data_folder, exist_ok=True)
    target_data_data_folder = os.path.join(target_data_folder, "data")
    target_data_labels_folder = os.path.join(target_data_folder, "labels")
    target_train_csv = os.path.join(target_data_folder, "train.csv")
    target_valid_csv = os.path.join(target_data_folder, "valid.csv")

    os.symlink(data_path, target_data_data_folder)
    os.symlink(labels_path, target_data_labels_folder)

    # NOTE(review): both CSVs are read from data_path — assumes the prepared
    # data folder (not labels_path) carries train.csv/valid.csv; confirm.
    train_csv = os.path.join(data_path, "train.csv")
    valid_csv = os.path.join(data_path, "valid.csv")

    train_df = pd.read_csv(train_csv)
    __modify_df(train_df)
    train_df.to_csv(target_train_csv, index=False)

    valid_df = pd.read_csv(valid_csv)
    __modify_df(valid_df)
    valid_df.to_csv(target_valid_csv, index=False)

    # Register this collaborator's data folder in the OpenFL data config.
    data_config = f"{cn},data/{cn}"
    plan_folder = os.path.join(WORKSPACE, "plan")
    os.makedirs(plan_folder, exist_ok=True)
    data_config_path = os.path.join(plan_folder, "data.yaml")
    with open(data_config_path, "w") as f:
        f.write(data_config)
b/mock_tokens/generate_tokens.py @@ -23,7 +23,15 @@ def token_payload(user): } -users = ["testadmin", "testbo", "testmo", "testdo"] +users = [ + "testadmin", + "benchmarkowner", + "modelowner", + "aggowner", + "traincol1", + "traincol2", + "testcol", +] tokens = {} # Use headers when verifying tokens using json web keys @@ -34,4 +42,4 @@ def token_payload(user): token_payload(user), private_key, algorithm="RS256" ) -json.dump(tokens, open("tokens.json", "w")) +json.dump(tokens, open("tokens2.json", "w")) diff --git a/mock_tokens/tokens.json b/mock_tokens/tokens.json index 6c6bcfc12..532c3e600 100644 --- a/mock_tokens/tokens.json +++ b/mock_tokens/tokens.json @@ -1,6 +1,12 @@ { - "testadmin@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.iZwlrNHjT90aZt_puWQnNke-7IrtQQ5FXsxpGfrYQjRGGXG4mgAvqI-9o-D4MWw8zdO0pDNddbQI44aoXDa_oOUpo23qhqjo-AahIKKUGu4W166cV6G8lseza7xr7WtZqEn_WA2qJR-IcqZvu80Lt6nURR-7tl80cLK4NdD5TmOvTOZdn4psgQg1uWrfWCLcQvjvfEtGPxHij1zu2usv5FuyDytp49xjFbH90bnepkIV0Jr_BfUZEm75sRf1wfj8c-t3IhqdWySfR0gSC4UW9ieaG_h7_kxRI_J3qfUwBklbtCMkOnApA4FaRUnv48fRBWCGxtU_1AVHbUwwPldMfUd8cDf_76Ipi31nIX5PVw7g7O00L23-CyjGf23U2j4Srz1xBHG_u3HAoT7XXOPpjaLGI0y021e9x7i1GWHMzqzcGcNUlJj8GMfocTJOLR1y4UNYvvWFuhaeqOHpVclcJ22Mo9JjsFLfy5D5TPetk2vBD0bExCgAOmAmhdnSEY96OxItjWYfSlZuBen29JD9NaUCwK4knQm2NODnKeTIS034EQsWGqXT-84VUdR_pJr_-seNzrmLxD9pLNfW3XARgscE-7Rfg29cdADj0RU_KNblkyNjwJzn1XxUI54IRlj3oYTQ8R4VzFag2NbJgqDpg6EQYg8Ii2X1v-8QN3tvuQQ", - "testbo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.BPDuCCa5akAzJmM7qSuUB7A0fmSkGGOKl4bCdRkyeYBCYuPofLHXCeEuILw9iZxKIKGpxiOal-6JLWlo6cteW4-ebcby3u7zV2rc0pWWvsezsFuYQt8FojAi6Knv5R0BbgSLtVxc_BGR7vK0apVEc9VU4ootfDitdHDHFGo2QO90RIIIg1toWIK5NkK39WQLvvUnEXhtvrhejeqFPpJj7SgrggytoW1ZZqGmtFDKhJii1cKdW5anNOsUdD5L386lgh5K5n_nxim63MrZI8wdJvLW16_NcvVRYOrgfEP7jp1kyb5Vmv_NQaS9CsnSMewv-JA4lP9LC3bs1YixOEcHYP6T9z1g6hhV96RkpCjIZuo5QbBYKVsONscePZDLTdlj2NrfMuyjt7NbrJWJpmaOCmqKvnQDIo1gdDd940_kBgcLgrjtTn6LfndXsAWM6U6_x78uKL73XoJwQwPXVwF2_hyo6vEHufx9rfo7WyPD5vTHb_1FrZQE-FZ_0DmOpe4tVaDy9nkfcZzu_jlYMnFMTJSn8te3hs9JWyfDGTwLY9r0WW5htRAS8dHSOGRzXpMCD6gHvcbYEpcEewGGJiaWombVUOCJg-x1Ax8zXx2-k4vFol0_7jYPw0EsduqTtjXChdw8NGDVUKLjjLlAZ_oUAYRCaYOdBpBNFFKKIfJ6oGw", - "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.ZDTNB9UAMnTA67QQXBBqPBjG96QJbFfPweL0AEioSMu-VzZTQlSdmgmXR36Li8opkBwxaVpWKCyP95eCgEAZ6Fdr_kE259kHJxHTD8UrejN8vj0ramqmXfd2xIxEN7q8YGzt4USGXFVII_iCehPandtuXfsqvQrcuu-dlfbsb85Azphi_6SWM3-2U5Vit7bDOeP218XHBy78uhW9mLZt6d838Hk97U9tr2vlhrO7bCOhy5SHSP4svhkF-wIglxarIxPu7RUsYshJp21aY976tJ39_RDS9ResIytYZGrAUacKGU5OJihyaaR_WNoppXGPxJsfbGyqylj_XC9_jgsnZCyfztahzOeWyCjjQeosSiB6dc7cmzlgeDcBtXAvkEtYg6B8SFR-c-NM6Wy0Wd4L4UQW175ySCBRwfuGUlKTszwG7LyRcxolui7ESc8E1PynczfOXxdXut1mxtrXrfKX90jgIV_wVR9LGYF0IVteJS-kudpabyDG65-LAli15ZwYbNYDaFLSit6j4W_sVeN9zZPA6cm37hfzI6pqf1J42R6_hmL1lmdv2Aq_6kumDShxVDugpjpnCt6vaZFphqJR5F_MIW587brzrbVVVNuaD1T1Gf0WJgX9-FaFsvRwXMhEdTfyXDPcnIdYH83T-tNNFyZCJek64-e--WSnSrRrQYs", - "testdo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDEwMzY3NjEsImV4cCI6MTE3MDEwMzY3NjF9.A-6hIS7Ua0nXr89Percg0YVjYmWFcpQSK66URRAjByTuLnvfOokfEyTULpdjB9uNOSEwIOoH55Y3o2lNG5_M9PVshScnPXR8YR0Ow7elU0GIfdcv3YDYu-YzH0iTbrVEc0X91J_SPp2vN5dSO0UVLHE9fQxxL1OzYeu2w9yHTu5OdEpJC7yPLPUu3dbKhknhwO4OGt_9hBD4b8X85f8js89d9YA9gl7XJydwNjdnNfv8ZUXwDP9b77h0RBeGGNkyYNNkRAex35g5D4Xd9evPsf7obHZkRTjiF-RMToCeJNnjv0BzHlsiz66qAONjW6LCcAwW7-ACS282YYP1yRhr4NXQinfBBo62auYa25sRTHv6uRV7IMrw01OovbVb4lbuKWTZQ0TUmW_UlQujv_EpzoXHcC__ZvMVRG9einjLdj2MfUHuvmxLeM2OYK75VzjP63YKP_67O9hAbAzyn1Z6-CQAgom8coGUV8Mdwz-L-VmvpYMdmCFohYRTqspluj-O-hrMNKexycdbbuflodkRhn8Pi35fUqb2UP5eQvhXusX7ob-YB7PHaka6xY8Wsndb_blEoTkSUmbxeSWt8QKz_n6AWBO6LCgOCG1LyAv2_DaL2M00sJ9oLCGDYnkDRoc_Aq45GEHxIS8cEQL2siFQbrH-KPT8PlKshNjMLtyNcO8" + "testadmin@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.oKS-F8L6Vv70vKklpX2kBiBkkKCh9D-RtxdVzSAiNk9MTDIJ-_7wNu0yBiEnX38IOubnifkh8v5OyAJ85dU5Au1LsekI0YyTI0WLQXKbywP89vfYZlfEIACvPUWJhRbHMJOGn-WVrPuEbGMDuDw677xOm5T04Hol9Qg4rAsNjYt05SnVwM4ico2CH9AR0LSrsdC_QCpLVvym9ewE1CrstmalPWWM3SeBves4qGSlwl1oTXgoOUXgK9DwxLHB-r66XZrcNwXZRuBYSTeqQvDnGP8TG4bXL1gQbvkh2tDbtj-DEyfUXxPxN_GVnlqk6I8BS-A9IoWiKdf0rYatelHd7aWbBgdCg6fxJ4HL4vxChqi3-X6dH2O4vUGkTCR0Td5NDhhe4gfj8WxXD293i0Glu1xOO8DVnu4j6GDfK8WfXtUgHwc4FJHb6iXDJqhAnP2jy_LPSEfjnItKQNvRyu8W3D6LxcumCLO6IvhFkOFDzcjkkyEIdLEiUzDkNeiM5eCqmuaoYW89ARy_GTMKMaLKRWyeGzeNP6celYJtoVzaUOeQvlD7mTFuRupiVC5PxuVFysdTUqU97MSRGSrjqxpghKcEHzVH_mJypWM_Psv4v_6zUhzBV7VKPfGTyareK7Wl6aWPKrGYVm23aAWAjcMgd6otQpjKY6SIkLCIRASuZlE", + "benchmarkowner@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoiYmVuY2htYXJrb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImJlbmNobWFya293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNjk2NjA5NzcwLCJleHAiOjExNjk2NjA5NzcwfQ.OPgLiySauDxTtjjAEvcOoklpquVOsu95XBECxDvn1v1h1NW__h7ob21bytCJic34wfF_b2Zy9FOXF8kTKUwmVqwA3lYn9i35R9LegGPyEqNMOQ89ou7xZFJvxUJYm92fz9R0oIh99swyCuAqaze-B5I8B8lH8wUBPlsaV6-EV-O2VOPmACQamLJzmuLhKg5P9cTP5dngEytHK6AFzIiuDYk_7JKfy-DCsxPRnpQV8Ct_dZgwV3RxTtGaFKktttrgTOo7DRPjXw5q-BzLX8RQYL8W5Y6taQ4qPVz5Q32EYKGzvCJtNj-gEFE5p30kfv1DMKyHwe9WnYgGG1EMREnyUWeHNQWG_SUfk07sE_RJpQ5FgvQFT0PGP0vKX16YUy5CqXQflRN45lv9qBFUEii1v4ORDJbYtrgaTEQeOLaL2l_7ucnNQsgeKhkiXLUdWedueUv1Zjd_Y5yL0isG6pUQeDJXtt9FugZMkzrWX7yqrRmqDy4oX0DMobCuLXy858MWgjwd1xEq0vN174OFjW1d2fJ9unNlx_A9ZRAPIV9L3ZtZ91yqlKfkym_8WTzcUyHzeeirQ9Gq6_vVYHIBcALJ2_EQWSWqbZXJtTbUMxT4IIGWZtXaDOMr23HsxtcuMQkjFVFJ0BW0SNpUByc1mIis99RtYHbbSPzpuWHhEKhhfm0", + "modelowner@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoibW9kZWxvd25lckBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoibW9kZWxvd25lciIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTY5NjYwOTc3MCwiZXhwIjoxMTY5NjYwOTc3MH0.YKGELPYlzHjT1M5C83xo6wLSMQ1Cj5O9AixoUUikc4PYXbvaaz6kZ92hc1vD9Oy9LTp85JKMkiwHkkmenkkYrEvkZzc4WTYaKOqGbn6fKpErBnSgszhbPnh6oI3hErSYI2hAI42v0w2H39vY4Kj6dZxo-1grWZG4D_o1xcc9OM4BAr9cD2GEQVrtURTbF2j46gz4uZZHWZZEyRCJDFfpKq9X0EdhR8muFvKQZ99Jfp-omM9vHZ7E4Bj9W4K15xodhKVzDwlFVBK-oXkvPo08-vMlWMQvXxB3dKBkPQMjYUOAstcdi4D6mEv9MfDKwxXIY_dKsxINkReU-6CdSRDO1mmc2SJ3k362Bd8r_Aq3T9P57VvsxyUxdD8RzOfuJk2letHSbhkJ4XuJeARDaF64oygk3-jLNhuZ4LEDMG00BqstIyEH8WlhKGhDU6AK0GrHytSp0NeeMyAXPGJg-OJn29eV1SW2N4UzOQ2Fnjj9klr07zM1U_vc68P4pKIwvnkoWVeNBjrU5sVugqziVX_BnbNHaOYfWbNlIZ4ngkjapr9Xr8WJd5yC2bcp90hgy9cbaLbjEKmF4mxnzE8IEkjXMzGDAb9oWpYSUnU0U3dxaKgdtDB8-mmjlcYpojcN7Iby2GLv8DSTuW9iojxcKw3YLoKONCfEVKI0ssc_R4zb9dI", + "aggowner@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoiYWdnb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImFnZ293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNjk2NjA5NzcwLCJleHAiOjExNjk2NjA5NzcwfQ.E2UH220kMwuow7bfKt13DCMAzY6i0YO7AciNqGXl9ApWT6kJiYC_swgvzkfDd509YEEcEeiVS8Ik6xfWCblqXVPkfdYLm1MmzHbnbDODIfeTd9uPAmIrBkDYY3vnQtpcg7NJS24VxO-2xa35YY6A8FsWGD3QqnJd9tuken5RQ4OAZTPaqvsPiYDLXfYwieM3bVjM7o7GsmykOiPN8E0-4qH78COWBJU9izymluJJhEetW8CdCxmJ7PomXQXvWrCoQjCf3J_8i28TABIianjHytdFSpXxqE14IVwn-OU7qj6V6zmjr5gfSmocNVcb3kOrwb_QJv3Cqww7tUJwF0q5_EncCtVB0XtZkDFRCBzQynI8-VR-z4eZ6SpK4RdYsBYHNydGh88RfKvcszCceAh7MGhh1XEjaKgM_IQAjnrXlMv2dRmvuSxooXXVhoZ1lx6tq9Sfvf4FvuOJCAVicPfaiMpoZW0UblEQn6zhugu3hiu6huDaYd5Opx9ZyRaq_fH_SNwnd9GvioqNNgOyatrCgqFSJdMQlMfHv_lXl0tgR60wFpWuTAZSz68bMvivwACunN3C8XwER1i-SHBO6OwoKpXMNW1bOaWPo9niggY1LdQuBeLauKG35Ee7rQZ8sVjKadu58It6GHBEQG_jLJN4puFmSHH6eMhwGQKjaVuOKS8", + "traincol1@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidHJhaW5jb2wxQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDEiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.lz-yInC8p4ENvIfWghre7B9UuAZKP5yQaUc6V2-0QsgP5fq3HnUpd5zl9DaWLDi0WX7OOk9CbY5aOFKgOZ1g4ulSh0QX382R3QpUMmKbvaZrVVKaaN1AaNmjv_89bKNfqOboz4z3hj_9S5F2I2aNy-LYYKzxa9cbMyHCieaA9-KZdZRVV-tK1co-nxnVU0QlKAf6TQwVEaBfBPWapFwnyqSs9v7M6WixKmyr0zWjHupKObmncpVkKmIh3HjTXEtdxeC4F4-V7xPJBHoK_hbTcnzySzQYmZWzgHiTWK_lM-U9Y3ugcPDRoVnIoL1_tJdFrBtsqgTqYvgGAW_1gwd94eWvwCUzpNBjr910byRPcFHdlZX11vnmhESVdZV62wQeOuwacR07FSWGxYHgAM_UWhySOcoB4qtza14tTb8YrlIZVUMYxoQYlB02RY5ZXvV00RqKJWJWKdQYCWgheeuZulTRC1Y4V9eu8GRZsfnk4omWVmkkcmx_Q3JHuEiGFfZmTZTF9p0m2tVZBWE4ML6EnX7ndJ8scZnsqZGkI8WsEabaPmcjg6elRGiqIMsOj7oigEiYqKnxiWQ3bAlXUbXAMYX_kAb9y3PVoiMsXpT9cFkTbOkFTU0_aYlcUOsIAZLhu8dMw29fJ6r8ZMgblOwo2jsgqTrAUqeBM85YXSz6NAM", + "traincol2@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidHJhaW5jb2wyQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDIiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.PRs-r2Iyye63fWyplOhGCQhKNHRMCup2upgpIWmh-xXAV1cBT6Wq9ANYq7oUmq03G8Aw53ZX0H8wV67nIhkD_dG8gEV_Hp_36rv8EmPUdlBYLQP0rV059zZP7s594sXtZ1G1hFZvs18XY_GNJuyOpZ-GdWH2nPjAvyGfU8JBYVDjf79HgJbrbfDLBvlRmrrCCA40bO6ScNrTXPnsBuefEuLEGqWsBiVKU8hOoBnPj8NAanxQBjgpph7kkPU7kmaxHn9rJEp3-S_8Ozi3J635roOKgumysDDrwcDt7oPJLrL6SVhHWzmMzBxN-ozdXAI9sJw1H4_bw6CG3MoahNJxwr8kHdFh8GqA58K_aVOAZNeSa_EUJGjquVkidOKyKcfcammHgw6cUmMk2Y1GepqSr4-KRjLrewkC3jxdnCeWgqPoiUrpCc8OcRGXMhyiqvYJUFdnfqEOMvXN5KCo5KeGew20h3zLiJaacDNQDufjLycq7x2DbS5UmtWXgAbGQWKllaVIwrIOAJ-Ev8HKXEXsYcPUqrTel481ptzChZ0Co9HokXGhHf9R0s04zwc3c-jQUqmcMa3nYSuEoInQdgjP9xgP2jXTyfwj-Pf5BZ8Tisurt_PWugcheKWXBL4txVoulfsN1CEDR7sXeWIbSVhRXX15EuASETUMMEzlI5WaXM0", + "testcol@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGNvbEBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGNvbCIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTY5NjYwOTc3MCwiZXhwIjoxMTY5NjYwOTc3MH0.XEj_ft8hxKm1dphiBhivx-SWG2BP39k1cMkArDHviKtp1vlROUUN0GgE_pkNfyn_6Rx9JUCp6LyDbCzCSC0ZlZ47qnuCy-rfVPUNRhKE_UYzFHUg8joaCJ6_3o7gzAoJLoHl6oc5bL3LcQIPkIqCFyoBp3JUaT0Arrv_tqSw6Y8WyQVxfkmjcnRIKxpCvUPW0SfVq605HigyULvfWyrJSepAn_Mw8bfDGEgCE3DfkODI2t0qpGuM-m-0neA99jk13VJdtEYVU2c_rbiGc0W8fBvkwytlqCit2JVDT6SlLgkJHr5WJQkP4Jx5AJ21bdTcWVMJu-xInO8yXSvLG_u0rRPl5oNImYXlO06c0PYCXxTuqqbljm_Knj48YcoBWrgWuxB-geH_LCXPxT7O1pQoosodBPwNHwlMt3C84rcUXeOxPTErcDputY-UNnuKxK0XNGMpsXnJZiQRh3I_4v5FtHE01DY8w9XjrMxLjBnUuISkz-Ct195wG-Od4v7Sw_7ikey0fn6f0EPo4ETmM1q5Oa_1RHV64GWG5NtaDezZbqDYQYCgFxm1coUIbPFwSWv7gT7w15Mj1kYO-RoLmt7DtHmTLu9ilebNhfUV34ZABnt96RFo_bSJY_t2eeWYSwRlnkCDbktmtBKlG5o8LKQVJq6Uki-90H2-jyIG1dp5BRg", + "testbo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.c5dnk-e8ys8MpaRzBkFe_RQTjGLKQlGSkJVNBhiKbzJV6565iHQ1HvFp88q-6ldGIwVWw9ZyY6QH2EWbPYfLF69-KnNcxUbDOJe9jBX0UAfaUsrcsaVoxLJPCojnjqKoIKgu_NM5PlEvsn4ojYA8Q-DxJ7r9RexnhxG0TxU_CwjhYuzV9RQsE3phWbHmcFv1-OIGWr72q1p8QwfaQp42K4iyUC7u4Buk6we9V4NJIJPPjjadkmdsYnm3pJxYVhKZx1pTkGnHTt6YEilQM9Iwgw_1mA5o1AYkwTm-_9lMxbSiwMzdxZmM1S_L-XoVlrWmyxeu2-BdLMIt49LS3fSXXDkcGyLboxLlg8v5rCMkhpUvDvw-dwpVUi1Y6QIVIqQocDJ3Bj1K-5SbGISc2wU7Aa9GNXe7GcEFn9DoeCthy3aLfucc3l8usZagopkAjVSGmClSJNw90VoWA7kER31E7Ehas5mnlJedGeyNV1wm5r2sMJfUsnoVqaCMHNFE4SQJaMVkGQOS4gn-G_8_WMLVrvfT1f_dg-dvISWrvhZ7-Uow-gi9-px2oY3Ehk2b82CPRMf3O20HOppLTg7ETqm77wQq9Elqrn2KoAYgc5Rlr-JZEq76NI37AgYRDIRistLI9_-UzJmY-YYbep2PsiWBx8WrUqm8uDIKDoGhfjBElQo", + "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.FOKlc3ZKH193-cINs9fxIPitsPeE2RZYl_H96xWaKIGOYOec55hR1WgybXD47czGUuTLS8NQGnFcTf-CGABWhFs7f3oTWKxpqo98sOGIhXxASgktfKmqL8tZCv-a4LCUoGFfeekWCvSo_fezkFnxZT1090JpPYpwcbjbTPY4u0JSU4XbaQh4y7N0lJXlmaZuAb13ncur7uY4A6Onl3u49m3tRjbg36r8fyzUwKppMyo5JdLS4h8099ZW7B6v32Xsr4TEz-UOA88YF_mmgK1P9BaFL7aQuARqsqvIQgdmH4JYxnoyng-UiroANy_HeeMJjhDtkuAVUsF6UOG_excgC9jA53GKNCSxxqERD-mxvZ5juWyL86fTSwuaBALWVv5HFUjHbxsA-2JzKt-ZrSD93Q-DqTi36f1WGfwphypc-2W201d4l1EJSg2b4FAqqt2dDDV4deGg3FCkR4vHXdA5aWIihZWmtSLsCylkg9bJiR0iZp6w-6I4R-8KDTG10UgFLPQDS95yLhClcjUjFnmI-xs5rLJvfuIbQU58yoOlBNwVPEm-G7vfTpw7kl_x3qjV_mOw2QSJE0_iQ_gi_YZ1Lzh4ueEgBFYF7VPrLV4x2djJ-Yd1Qb--YQaBRTTFpZ1A8EqUQ-BCn0K9FoWgT2ChZj-B_Oks6z5Z2sqXyauYFU8", + "testdo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.PPfm_uOpz0SuAUEVKE-630mSDuboRUUliSoL5CiE2gUmzxHW5iyJ-37pOfuO5e6y0A4YMnh_xTv3yxgnG5OOB_6ZLD_KF6e5BRJWkZhqKixGLDpQReUv9PadR02eMgtaQBKNuy9Ey8EV8mtyeIUH1JDoIwO4Lt17XHzvYwM7JdCH3TguxPYvCMORoLwKFfhPUlBeMDnjQBnOiw7gv46CdoXHAUqj7k87gp04opgSaeA1tsJIzakZiwXLK4CJvTo_jIhd_w4RNXYHiacI5AtoI8zJKzIdmptkJKcJv2yWq8nFpBlpAr09X-c4haBB4xGeX4743yOBQ6jMdW2f6mqOPTypbEn94_tbn-HaqbZJoclSwByDX0AN3j5KZ0-W-zR3CkPazkCqIWNLO9dte1fz2iocmsAmYdxmvcCN-uAMFI1tRaYHHTd3lyv-GOOllhs-Pc-H5fBfuN-H-l1LejdtdIN9qAxx5BgGVla_ac1mdHEqljO4AuppK4dMcSzw3BUHU4R0uqZ81-a-RoahG8SiPvihjl8foWCaB86qCBJDGItJno2Zdnw4Qpk-EeVSCnQAz0PzVhw9hhhNlAeO8GvKL1-OR1aiXn126empXYM7-wDjs8aAEY4oCo5987GU-pbdVwE8As2FCFI6vbW2B3JBmKideidFpeT92JWqeXv8Q_M" } \ No newline at end of file diff --git a/server/aggregator/__init__.py b/server/aggregator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator/admin.py b/server/aggregator/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/aggregator/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. 
# NOTE(review): this span of the patch covers several files of the
# `server/aggregator` Django app; each section below is marked with its path.

# --- server/aggregator/apps.py ---
from django.apps import AppConfig


class AggregatorConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "aggregator"


# --- server/aggregator/migrations/0001_initial.py ---
# Generated by Django 3.2.20 on 2023-09-29 01:02
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
    ]

    operations = [
        migrations.CreateModel(
            name="Aggregator",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("name", models.CharField(max_length=20, unique=True)),
                ("server_config", models.JSONField(blank=True, default=dict, null=True)),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("owner", models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)),
            ],
            options={
                "ordering": ["created_at"],
            },
        ),
    ]


# --- server/aggregator/models.py ---
from django.db import models
from django.contrib.auth import get_user_model

User = get_user_model()


class Aggregator(models.Model):
    """A federated-learning aggregator registered by its owner."""

    owner = models.ForeignKey(User, on_delete=models.PROTECT)
    name = models.CharField(max_length=20, unique=True)
    server_config = models.JSONField(default=dict, blank=True, null=True)
    created_at = models.DateTimeField(auto_now_add=True)
    modified_at = models.DateTimeField(auto_now=True)

    def __str__(self):
        # BUG FIX: the original returned self.server_config (a dict/JSONField
        # value), but __str__ must return a string — admin/log rendering would
        # raise TypeError. The unique name is the human-readable key.
        return self.name

    class Meta:
        ordering = ["created_at"]


# --- server/aggregator/serializers.py ---
from rest_framework import serializers
from .models import Aggregator


class AggregatorSerializer(serializers.ModelSerializer):
    class Meta:
        model = Aggregator
        fields = "__all__"
        # Owner is set server-side from the authenticated request user.
        read_only_fields = ["owner"]


# --- server/aggregator/urls.py ---
from django.urls import path
from . import views
from aggregator_association import views as aviews

app_name = "aggregator"

# NOTE(review): the angle-bracket URL converters were stripped in the patch
# text (it reads `path("/", ...)` and `path("/training_experiments//", ...)`).
# Restored from the view signatures (AggregatorDetail.get takes `pk`; the
# association permissions read kwargs "pk"/"tid") — confirm against the
# original source.
urlpatterns = [
    path("", views.AggregatorList.as_view()),
    path("<int:pk>/", views.AggregatorDetail.as_view()),
    path("training_experiments/", aviews.ExperimentAggregatorList.as_view()),
    path(
        "<int:pk>/training_experiments/<int:tid>/",
        aviews.AggregatorApproval.as_view(),
    ),
]


# --- server/aggregator/views.py ---
from django.http import Http404
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from rest_framework import status

from .models import Aggregator
from .serializers import AggregatorSerializer
from drf_spectacular.utils import extend_schema


class AggregatorList(GenericAPIView):
    serializer_class = AggregatorSerializer
    queryset = ""

    @extend_schema(operation_id="aggregators_retrieve_all")
    def get(self, request, format=None):
        """List all aggregators (paginated)."""
        aggregators = Aggregator.objects.all()
        aggregators = self.paginate_queryset(aggregators)
        serializer = AggregatorSerializer(aggregators, many=True)
        return self.get_paginated_response(serializer.data)

    def post(self, request, format=None):
        """Create a new Aggregator owned by the requesting user."""
        serializer = AggregatorSerializer(data=request.data)
        if serializer.is_valid():
            serializer.save(owner=request.user)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


class AggregatorDetail(GenericAPIView):
    serializer_class = AggregatorSerializer
    queryset = ""

    def get_object(self, pk):
        # Translate a missing row into HTTP 404 rather than a 500.
        try:
            return Aggregator.objects.get(pk=pk)
        except Aggregator.DoesNotExist:
            raise Http404

    def get(self, request, pk, format=None):
        """Retrieve an aggregator instance."""
        aggregator = self.get_object(pk)
        serializer = AggregatorSerializer(aggregator)
        return Response(serializer.data)
+ """ + aggregator = self.get_object(pk) + serializer = AggregatorSerializer(aggregator) + return Response(serializer.data) diff --git a/server/aggregator_association/__init__.py b/server/aggregator_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator_association/admin.py b/server/aggregator_association/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/aggregator_association/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/server/aggregator_association/apps.py b/server/aggregator_association/apps.py new file mode 100644 index 000000000..4df2c898d --- /dev/null +++ b/server/aggregator_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class AggregatorAssociationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'aggregator_association' diff --git a/server/aggregator_association/migrations/0001_initial.py b/server/aggregator_association/migrations/0001_initial.py new file mode 100644 index 000000000..f6f0825e5 --- /dev/null +++ b/server/aggregator_association/migrations/0001_initial.py @@ -0,0 +1,36 @@ +# Generated by Django 3.2.20 on 2023-09-29 01:02 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('aggregator', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='ExperimentAggregator', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('certificate', models.TextField(blank=True)), + ('signing_request', models.TextField()), + ('metadata', models.JSONField(default=dict)), + ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 'APPROVED'), ('REJECTED', 
'REJECTED')], default='PENDING', max_length=100)), + ('approved_at', models.DateTimeField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('modified_at', models.DateTimeField(auto_now=True)), + ('aggregator', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='aggregator.aggregator')), + ('initiated_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['created_at'], + }, + ), + ] diff --git a/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py b/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py new file mode 100644 index 000000000..536303150 --- /dev/null +++ b/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py @@ -0,0 +1,22 @@ +# Generated by Django 3.2.20 on 2023-09-29 01:02 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('aggregator_association', '0001_initial'), + ('training', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='experimentaggregator', + name='training_exp', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='training.trainingexperiment'), + ), + ] diff --git a/server/aggregator_association/migrations/__init__.py b/server/aggregator_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/aggregator_association/models.py b/server/aggregator_association/models.py new file mode 100644 index 000000000..0b069ca96 --- /dev/null +++ b/server/aggregator_association/models.py @@ -0,0 +1,29 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentAggregator(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + 
("REJECTED", "REJECTED"), + ) + certificate = models.TextField(blank=True) + signing_request = models.TextField() + aggregator = models.ForeignKey("aggregator.Aggregator", on_delete=models.PROTECT) + training_exp = models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/aggregator_association/permissions.py b/server/aggregator_association/permissions.py new file mode 100644 index 000000000..a4723b0fe --- /dev/null +++ b/server/aggregator_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from aggregator.models import Aggregator + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsAggregatorOwner(BasePermission): + def get_object(self, pk): + try: + return Aggregator.objects.get(pk=pk) + except Aggregator.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("aggregator", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + aggregator = self.get_object(pk) + if not aggregator: + return False + if aggregator.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = 
request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False diff --git a/server/aggregator_association/serializers.py b/server/aggregator_association/serializers.py new file mode 100644 index 000000000..0360f1ed5 --- /dev/null +++ b/server/aggregator_association/serializers.py @@ -0,0 +1,156 @@ +from rest_framework import serializers +from django.utils import timezone +from training.models import TrainingExperiment +from aggregator.models import Aggregator + +from .models import ExperimentAggregator +from .utils import latest_agg_associations +from signing.interface import verify_aggregator_csr, sign_csr + + +class ExperimentAggregatorListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentAggregator + read_only_fields = ["initiated_by", "approved_at", "certificate"] + fields = "__all__" + + def validate(self, data): + exp_id = self.context["request"].data.get("training_exp") + aggregator = self.context["request"].data.get("aggregator") + approval_status = self.context["request"].data.get("approval_status") + csr = self.context["request"].data.get("signing_request") + + training_exp = TrainingExperiment.objects.get(pk=exp_id) + training_exp_state = training_exp.state + + if training_exp_state != "DEVELOPMENT": + raise serializers.ValidationError( + "Aggregator Association requests can be made only " + "on a development training experiment" + ) + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + aggregator_object = Aggregator.objects.get(pk=aggregator) + last_experiment_aggregator = ( + ExperimentAggregator.objects.filter( + training_exp__id=exp_id, 
aggregator__id=aggregator + ) + .order_by("-created_at") + .first() + ) + if not last_experiment_aggregator: + if approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject association request only if there are prior requests" + ) + else: + if approval_status == "PENDING": + if last_experiment_aggregator.approval_status != "REJECTED": + raise serializers.ValidationError( + "User can create a new request only if prior request is rejected" + ) + elif approval_status == "APPROVED": + raise serializers.ValidationError( + "User cannot create an approved association request" + ) + # approval_status == "REJECTED": + else: + if last_experiment_aggregator.approval_status != "APPROVED": + raise serializers.ValidationError( + "User can reject request only if prior request is approved" + ) + + # check if there is already an approved aggregator + # TODO: concurrency problem perhaps? if a user creates simultanuously two + # already APPROVED associations + experiment_aggregators = latest_agg_associations(exp_id) + approved_experiment_aggregators = experiment_aggregators.filter( + approval_status="APPROVED" + ) + if approved_experiment_aggregators.exists(): + raise serializers.ValidationError( + "This training experiment already has an aggregator" + ) + + valid_csr, reason = verify_aggregator_csr( + csr, aggregator_object, training_exp, self.context["request"] + ) + if not valid_csr: + raise serializers.ValidationError(reason) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + if ( + validated_data["aggregator"].owner.id + == validated_data["training_exp"].owner.id + ): + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + csr = validated_data["signing_request"] + certificate = sign_csr(csr, validated_data["training_exp"]) + 
validated_data["certificate"] = certificate + return ExperimentAggregator.objects.create(**validated_data) + + +class AggregatorApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentAggregator + read_only_fields = ["initiated_by", "approved_at", "certificate"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + "certificate", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No aggregator association found") + last_approval_status = self.instance.approval_status + cur_approval_status = data["approval_status"] + if last_approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject only a pending request" + ) + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + if ( + last_approval_status != cur_approval_status + and cur_approval_status == "APPROVED" + ): + if current_user.id == initiated_user.id: + raise serializers.ValidationError( + "Same user cannot approve the association request" + ) + + # check if there is already an approved aggregator + experiment_aggregators = latest_agg_associations(self.instance.training_exp.id) + approved_experiment_aggregators = experiment_aggregators.filter( + approval_status="APPROVED" + ) + if approved_experiment_aggregators.exists(): + raise serializers.ValidationError( + "This training experiment already has an aggregator" + ) + return data + + def update(self, instance, validated_data): + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + if instance.approval_status == "APPROVED": + csr = instance.signing_request + certificate = sign_csr(csr, self.instance.training_exp.id) + instance.certificate = certificate + instance.save() + return instance diff --git a/server/aggregator_association/utils.py b/server/aggregator_association/utils.py 
new file mode 100644 index 000000000..6a447210d --- /dev/null +++ b/server/aggregator_association/utils.py @@ -0,0 +1,14 @@ +from django.db.models import OuterRef, Subquery +from .models import ExperimentAggregator + + +def latest_agg_associations(training_exp_id): + experiment_aggregators = ExperimentAggregator.objects.filter( + training_exp__id=training_exp_id + ) + latest_assocs = ( + experiment_aggregators.filter(aggregator=OuterRef("aggregator")) + .order_by("-created_at") + .values("id")[:1] + ) + return experiment_aggregators.filter(id__in=Subquery(latest_assocs)) diff --git a/server/aggregator_association/views.py b/server/aggregator_association/views.py new file mode 100644 index 000000000..846289e82 --- /dev/null +++ b/server/aggregator_association/views.py @@ -0,0 +1,76 @@ +from .models import ExperimentAggregator +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsAggregatorOwner, IsExpOwner +from .serializers import ( + ExperimentAggregatorListSerializer, + AggregatorApprovalSerializer, +) + + +class ExperimentAggregatorList(GenericAPIView): + permission_classes = [IsAdmin | IsAggregatorOwner] + serializer_class = ExperimentAggregatorListSerializer + queryset = "" + + def post(self, request, format=None): + """ + Associate an aggregator to a training_exp + """ + serializer = ExperimentAggregatorListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class AggregatorApproval(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] + serializer_class = AggregatorApprovalSerializer + queryset = "" + + def get_object(self, aggregator_id, 
training_exp_id): + try: + return ExperimentAggregator.objects.filter( + aggregator__id=aggregator_id, training_exp__id=training_exp_id + ) + except ExperimentAggregator.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp aggregator associations + """ + training_expaggregator = ( + self.get_object(pk, tid).order_by("-created_at").first() + ) + serializer = AggregatorApprovalSerializer(training_expaggregator) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp aggregator association + """ + training_expaggregator = ( + self.get_object(pk, tid).order_by("-created_at").first() + ) + serializer = AggregatorApprovalSerializer( + training_expaggregator, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + Delete a training_exp aggregator association + """ + training_expaggregator = self.get_object(pk, tid) + training_expaggregator.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/dataset/urls.py b/server/dataset/urls.py index 5aa23fd5a..ff07b54dd 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . 
import views from benchmarkdataset import views as bviews +from traindataset_association import views as tviews app_name = "Dataset" @@ -11,4 +12,6 @@ path("/benchmarks//", bviews.DatasetApproval.as_view()), # path("/benchmarks/", bviews.DatasetBenchmarksList.as_view()), # NOTE: when activating this endpoint later, check permissions and write tests + path("training_experiments/", tviews.ExperimentDatasetList.as_view()), + path("/training_experiments//", tviews.DatasetApproval.as_view()), ] diff --git a/server/key_storage/__init__.py b/server/key_storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/key_storage/gcloud_secret_manager.py b/server/key_storage/gcloud_secret_manager.py new file mode 100644 index 000000000..d36edd069 --- /dev/null +++ b/server/key_storage/gcloud_secret_manager.py @@ -0,0 +1,11 @@ +class GcloudSecretStorage: + def __init__(self, filepath): + raise NotImplementedError + + def write(self, key, storage_id): + # NOTE: use one secret per deployment. + # store keys as secret versions + raise NotImplementedError + + def read(self, storage_id): + raise NotImplementedError diff --git a/server/key_storage/local.py b/server/key_storage/local.py new file mode 100644 index 000000000..00cfe60c3 --- /dev/null +++ b/server/key_storage/local.py @@ -0,0 +1,20 @@ +import os +from signing.cryptography.io import write_key, read_key + + +class LocalSecretStorage: + """NOT SUITABLE FOR PRODUCTION. 
it simply stores keys + in filesystem.""" + + def __init__(self, folderpath): + os.makedirs(folderpath, exist_ok=True) + self.folderpath = folderpath + + def write(self, key, storage_id): + filepath = os.path.join(self.folderpath, storage_id) + write_key(key, filepath) + + def read(self, storage_id): + filepath = os.path.join(self.folderpath, storage_id) + key = read_key(filepath) + return key diff --git a/server/medperf/settings.py b/server/medperf/settings.py index 75c6af066..77e97bb29 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -16,6 +16,8 @@ import environ import google.auth from google.cloud import secretmanager +from key_storage.gcloud_secret_manager import GcloudSecretStorage +from key_storage.local import LocalSecretStorage # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent @@ -90,6 +92,10 @@ "benchmarkmodel", "user", "result", + "training", + "aggregator", + "traindataset_association", + "aggregator_association", "rest_framework", "rest_framework.authtoken", "drf_spectacular", @@ -287,3 +293,9 @@ "JTI_CLAIM": None, # Currently expected auth tokens don't contain such a claim } TOKEN_USER_EMAIL_CLAIM = "https://medperf.org/email" + +if DEPLOY_ENV == "gcp-prod": + # TODO + KEY_STORAGE = GcloudSecretStorage("") +else: + KEY_STORAGE = LocalSecretStorage(os.path.join(BASE_DIR, "keys")) diff --git a/server/medperf/urls.py b/server/medperf/urls.py index be4e07dce..f739ae155 100644 --- a/server/medperf/urls.py +++ b/server/medperf/urls.py @@ -36,5 +36,7 @@ path("results/", include("result.urls", namespace=API_VERSION), name="result"), path("users/", include("user.urls", namespace=API_VERSION), name="users"), path("me/", include("utils.urls", namespace=API_VERSION), name="me"), + path("training/", include("training.urls", namespace=API_VERSION), name="training"), + path("aggregators/", include("aggregator.urls", namespace=API_VERSION), name="aggregator"), ])), ] diff 
--git a/server/signing/__init__.py b/server/signing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/signing/cryptography/__init__.py b/server/signing/cryptography/__init__.py new file mode 100644 index 000000000..b3f394d12 --- /dev/null +++ b/server/signing/cryptography/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""openfl.cryptography package.""" diff --git a/server/signing/cryptography/ca.py b/server/signing/cryptography/ca.py new file mode 100644 index 000000000..d651919a4 --- /dev/null +++ b/server/signing/cryptography/ca.py @@ -0,0 +1,150 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Cryptography CA utilities.""" + +import datetime +import uuid +from typing import Tuple + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509.base import Certificate +from cryptography.x509.base import CertificateSigningRequest +from cryptography.x509.extensions import ExtensionNotFound +from cryptography.x509.name import Name +from cryptography.x509.oid import ExtensionOID +from cryptography.x509.oid import NameOID + + +def generate_root_cert( + common_name: str = "Simple Root CA", days_to_expiration: int = 365 +) -> Tuple[RSAPrivateKey, Certificate]: + """Generate_root_certificate.""" + now = datetime.datetime.utcnow() + expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) + + # Generate private key + root_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=3072, backend=default_backend() + ) + + # Generate public key + root_public_key = root_private_key.public_key() + builder = x509.CertificateBuilder() + subject = x509.Name( + [ + 
x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Root CA"), + ] + ) + issuer = subject + builder = builder.subject_name(subject) + builder = builder.issuer_name(issuer) + + builder = builder.not_valid_before(now) + builder = builder.not_valid_after(now + expiration_delta) + builder = builder.serial_number(int(uuid.uuid4())) + builder = builder.public_key(root_public_key) + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + + # Sign the CSR + certificate = builder.sign( + private_key=root_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + + return root_private_key, certificate + + +def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: + """Generate signing CSR.""" + # Generate private key + signing_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=3072, backend=default_backend() + ) + + builder = x509.CertificateSigningRequestBuilder() + subject = x509.Name( + [ + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), + x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), + x509.NameAttribute(NameOID.COMMON_NAME, "Simple Signing CA"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Signing CA"), + ] + ) + builder = builder.subject_name(subject) + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + + # Sign the CSR + csr = builder.sign( + private_key=signing_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + + return signing_private_key, csr + + +def sign_certificate( + csr: CertificateSigningRequest, + issuer_private_key: RSAPrivateKey, + 
issuer_name: Name, + days_to_expiration: int = 365, + ca: bool = False, +) -> Certificate: + """ + Sign the incoming CSR request. + + Args: + csr : Certificate Signing Request object + issuer_private_key : Root CA private key if the request is for the signing + CA; Signing CA private key otherwise + issuer_name : x509 Name + days_to_expiration : int (365 days by default) + ca : Is this a certificate authority + """ + now = datetime.datetime.utcnow() + expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) + + builder = x509.CertificateBuilder() + builder = builder.subject_name(csr.subject) + builder = builder.issuer_name(issuer_name) + builder = builder.not_valid_before(now) + builder = builder.not_valid_after(now + expiration_delta) + builder = builder.serial_number(int(uuid.uuid4())) + builder = builder.public_key(csr.public_key()) + builder = builder.add_extension( + x509.BasicConstraints(ca=ca, path_length=None), + critical=True, + ) + try: + builder = builder.add_extension( + csr.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ).value, + critical=False, + ) + except ExtensionNotFound: + pass # Might not have alternative name + + signed_cert = builder.sign( + private_key=issuer_private_key, + algorithm=hashes.SHA384(), + backend=default_backend(), + ) + return signed_cert diff --git a/server/signing/cryptography/io.py b/server/signing/cryptography/io.py new file mode 100644 index 000000000..52bfc5e95 --- /dev/null +++ b/server/signing/cryptography/io.py @@ -0,0 +1,129 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Cryptography IO utilities.""" + +import os +from hashlib import sha384 +from pathlib import Path +from typing import Tuple + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from 
cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.x509.base import Certificate +from cryptography.x509.base import CertificateSigningRequest + + +def read_key(path: Path) -> RSAPrivateKey: + """ + Read private key. + + Args: + path : Path (pathlib) + + Returns: + private_key + """ + with open(path, 'rb') as f: + pem_data = f.read() + + signing_key = load_pem_private_key(pem_data, password=None) + # TODO: replace assert with exception / sys.exit + assert (isinstance(signing_key, rsa.RSAPrivateKey)) + return signing_key + + +def write_key(key: RSAPrivateKey, path: Path) -> None: + """ + Write private key. + + Args: + key : RSA private key object + path : Path (pathlib) + + """ + def key_opener(path, flags): + return os.open(path, flags, mode=0o600) + + with open(path, 'wb', opener=key_opener) as f: + f.write(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + )) + + +def read_crt(path: Path) -> Certificate: + """ + Read signed TLS certificate. + + Args: + path : Path (pathlib) + + Returns: + Cryptography TLS Certificate object + """ + with open(path, 'rb') as f: + pem_data = f.read() + + certificate = x509.load_pem_x509_certificate(pem_data) + # TODO: replace assert with exception / sys.exit + assert (isinstance(certificate, x509.Certificate)) + return certificate + + +def write_crt(certificate: Certificate, path: Path) -> None: + """ + Write cryptography certificate / csr. + + Args: + certificate : cryptography csr / certificate object + path : Path (pathlib) + + Returns: + Cryptography TLS Certificate object + """ + with open(path, 'wb') as f: + f.write(certificate.public_bytes( + encoding=serialization.Encoding.PEM, + )) + + +def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]: + """ + Read certificate signing request. 
+ + Args: + path : Path (pathlib) + + Returns: + Cryptography CSR object + """ + with open(path, 'rb') as f: + pem_data = f.read() + + csr = x509.load_pem_x509_csr(pem_data) + # TODO: replace assert with exception / sys.exit + assert (isinstance(csr, x509.CertificateSigningRequest)) + return csr, get_csr_hash(csr) + + +def get_csr_hash(certificate: CertificateSigningRequest) -> str: + """ + Get hash of cryptography certificate. + + Args: + certificate : Cryptography CSR object + + Returns: + Hash of cryptography certificate / csr + """ + hasher = sha384() + encoded_bytes = certificate.public_bytes( + encoding=serialization.Encoding.PEM, + ) + hasher.update(encoded_bytes) + return hasher.hexdigest() diff --git a/server/signing/cryptography/participant.py b/server/signing/cryptography/participant.py new file mode 100644 index 000000000..d6e94712b --- /dev/null +++ b/server/signing/cryptography/participant.py @@ -0,0 +1,72 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Cryptography participant utilities.""" +from typing import Tuple + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509.base import CertificateSigningRequest +from cryptography.x509.oid import NameOID + + +def generate_csr(common_name: str, + server: bool = False) -> Tuple[RSAPrivateKey, CertificateSigningRequest]: + """Issue certificate signing request for server and client.""" + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=3072, + backend=default_backend() + ) + + builder = x509.CertificateSigningRequestBuilder() + subject = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ]) + builder = builder.subject_name(subject) + builder = 
builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True, + ) + if server: + builder = builder.add_extension( + x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.SERVER_AUTH]), + critical=True + ) + + else: + builder = builder.add_extension( + x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.CLIENT_AUTH]), + critical=True + ) + + builder = builder.add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + content_commitment=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False + ), + critical=True + ) + + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(common_name)]), + critical=False + ) + + # Sign the CSR + csr = builder.sign( + private_key=private_key, algorithm=hashes.SHA384(), + backend=default_backend() + ) + + return private_key, csr diff --git a/server/signing/cryptography/utils.py b/server/signing/cryptography/utils.py new file mode 100644 index 000000000..03f9eb940 --- /dev/null +++ b/server/signing/cryptography/utils.py @@ -0,0 +1,14 @@ +from cryptography.hazmat.primitives import serialization +from cryptography import x509 + + +def cert_to_str(cert): + return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") + + +def str_to_cert(cert_str): + return x509.load_pem_x509_certificate(cert_str.encode("utf-8")) + + +def str_to_csr(csr_str): + return x509.load_pem_x509_csr(csr_str.encode("utf-8")) diff --git a/server/signing/interface.py b/server/signing/interface.py new file mode 100644 index 000000000..024931da7 --- /dev/null +++ b/server/signing/interface.py @@ -0,0 +1,67 @@ +from cryptography import x509 + +from django.conf import settings +from .cryptography.ca import generate_root_cert, sign_certificate +from .cryptography.utils import cert_to_str, str_to_cert, str_to_csr +from training.models import TrainingExperiment + + +def 
__get_experiment_key_pair(training_exp_id): + exp = TrainingExperiment.objects.get(pk=training_exp_id) + private_key_id = exp.private_key + private_key = settings.KEY_STORAGE.read(private_key_id) + public_key_str = exp.public_key + public_key = str_to_cert(public_key_str) + return private_key, public_key + + +def generate_key_pair(training_exp_id): + # TODO: do we need to destroy the keys at some point? + ca_common_name = f"training_{training_exp_id}" + root_private_key, certificate = generate_root_cert(ca_common_name) + + # store private key + storage_id = ca_common_name + settings.KEY_STORAGE.write(root_private_key, storage_id) + + # public key to str + public_key_str = cert_to_str(certificate) + return storage_id, public_key_str + + +def sign_csr(csr_str, training_exp_id): + # Load CSR + csr = str_to_csr(csr_str) + + # load signing key and crt + signing_key, signing_crt = __get_experiment_key_pair(training_exp_id) + + # sign + signed_cert = sign_certificate(csr, signing_key, signing_crt.subject) + + # cert as str + cert_str = cert_to_str(signed_cert) + + return cert_str + + +def verify_dataset_csr(csr_str, dataset_object, training_exp): + # TODO? + try: + csr = str_to_csr(csr_str) + except ValueError as e: + return False, str(e) + if not isinstance(csr, x509.CertificateSigningRequest): + return False, "Invalid CSR format" + return True, "" + + +def verify_aggregator_csr(csr_str, aggregator_object, training_exp, request): + # TODO? 
+ try: + csr = str_to_csr(csr_str) + except ValueError as e: + return False, str(e) + if not isinstance(csr, x509.CertificateSigningRequest): + return False, "Invalid CSR format" + return True, "" diff --git a/server/testing_medperf.sh b/server/testing_medperf.sh new file mode 100644 index 000000000..861a3dea4 --- /dev/null +++ b/server/testing_medperf.sh @@ -0,0 +1,69 @@ +# TODO: remove me from the repo + +# setup dev server or reset db +bash reset_db.sh +sudo rm -rf keys +sudo rm -rf /home/hasan/.medperf +# seed +python seed.py + + +medperf profile ls +medperf profile activate local +medperf profile ls + +# move folder as a created dataset +cp -r /home/hasan/work/openfl_ws/9d56e799a9e63a6c3ced056ebd67eb6381483381 /home/hasan/.medperf/localhost_8000/data/ + +# login +medperf auth login -e testbo@example.com + +# register mlcube +medperf mlcube submit -n testfl \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=1 \ + -p https://storage.googleapis.com/medperf-storage/testfl/parameters-cpu.yaml \ + -a https://storage.googleapis.com/medperf-storage/testfl/additional.tar.gz + + +# register training exp +medperf training submit -n testtrain -d testtrain -p 1 -m 5 + +# mark as approved +curl -sk -X PUT https://localhost:8000/api/v0/training/1/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Bearer 
eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.PbAxtzBxPfipnuYGPx90P2_K-2V7jPSdPEhzHEW6u4KnUQU8Gul6xrwLsGlgdD19A6EzUtgfQxW2Lk2OITcOD0nbXcjUgPyduLozMXDdTwom19429g7Q5eWOppWdMImirX3OygWaqx587Q_OL73HZuCjFcEWwyGnhB62oruVRcM6uDWz4xVmGcAwdtMzCBYvQj9_C-Hnt9IYPgnKesXPr_AP98-bdQx2EBahXtQW1HaARgabZp3SLaCDY9I6h91B7NQ-PDWpuDxd0UamHSaq9dNPbd0SsR6ajl80wOKQaZF3be_TKJW0e0l7L4tnsbbSW23fR1utSH2PlNFPBx3uGGe2Aqirdq16fAWqvDNO8-kiVRpeikp0ze17lTYqtw2-GZIxXyc8rG-NPxz7R5lMg7ARu99e5nLGFHpV5sMNUoXKx5zoPO7Y7cO5mdzm0C_2DARB7imagKsL5eLc5fcYDEZBl0FtkDgT_CY3FEuH_X3DgPwEP6wE2IFGnU1zEXtuNd1XSUxvxxZ0_afoX54qNuz3m9qzAKuYJkkziiApdIPE_bXX2ox3-Z_Q5RfqvtLRJoE64FaOMr_6xCq_77hpPDpWACQaXCwn736-Jl8nP1HcGvdDa980dzKaih4mQ-FtFZ8xhMXU7jA_Bur9e2tg51TxBzAyd4t4NNk-gYaSUPU" + +# register aggregator +medperf auth login -e testmo@example.com +medperf aggregator submit -n testagg -a hasan-HP-ZBook-15-G3 -p 50273 + +# associate aggregator +medperf aggregator associate -a 1 -t 1 + + +# register dataset +medperf auth login -e testdo@example.com +medperf dataset submit -d 9d56e799a9e63a6c3ced056ebd67eb6381483381 + +# associate dataset +medperf training associate_dataset -t 1 -d 1 + +# approve associations +medperf auth login -e testbo@example.com +medperf training approve_association -t 1 -a 1 +medperf training approve_association -t 1 -d 1 + + +# test nonimportant stuff +# medperf training ls +# medperf aggregator ls +# medperf training view 1 +# medperf aggregator view 1 +# medperf training list_associations + +# lock experiment +medperf training lock -t 1 + +# # start aggregator +gnome-terminal -- bash -c "medperf aggregator start -a 1 -t 1; bash" + +# # start collaborator +medperf training run -d 1 -t 1 diff --git a/server/testing_miccai.sh b/server/testing_miccai.sh new file mode 100644 index 
000000000..70554af19 --- /dev/null +++ b/server/testing_miccai.sh @@ -0,0 +1,153 @@ +# TODO: remove me from the repo + +# First, run the local server +# cd ~/medperf/server +# sh setup-dev-server.sh +# go to another terminal + +cd .. +# # TODO: reset +# bash reset_db.sh +# sudo rm -rf keys +# sudo rm -rf ~/.medperf + +# TODO: seed +# python seed.py --demo benchmark + +# TODO: download data +# wget https://storage.googleapis.com/medperf-storage/testfl/data/col1.tar.gz +# tar -xf col1.tar.gz +# wget https://storage.googleapis.com/medperf-storage/testfl/data/col2.tar.gz +# tar -xf col2.tar.gz +# wget https://storage.googleapis.com/medperf-storage/testfl/data/test.tar.gz +# tar -xf test.tar.gz +# rm col1.tar.gz +# rm col2.tar.gz +# rm test.tar.gz + +# TODO: activate local profile +# medperf profile activate local + +# login +medperf auth login -e modelowner@example.com + +# register prep mlcube +medperf mlcube submit -n prep \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_prep.yaml + +# register training mlcube +medperf mlcube submit -n testfl \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=2 \ + -p https://storage.googleapis.com/medperf-storage/testfl/parameters-miccai.yaml \ + -a https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz + +# register training exp +medperf training submit -n testtrain -d testtrain -p 1 -m 2 + +# mark as approved +bash admin_training_approval.sh + +# register aggregator +medperf aggregator submit -n testagg -a $(hostname --fqdn) -p 50273 + +# associate aggregator +medperf aggregator associate -a 1 -t 1 -y + + +# register dataset +medperf auth login -e traincol1@example.com +medperf dataset create -p 1 -d datasets/col1 -l datasets/col1 --name col1 --description col1data --location col1location +medperf dataset submit -d $(medperf dataset ls | grep col1 | tr -s " " | cut -d " " -f 1) -y + +# associate dataset +medperf training associate_dataset -t 1 -d 1 -y + +# 
shortcut +bash shortcut.sh + +# approve associations +medperf auth login -e modelowner@example.com +medperf training approve_association -t 1 -d 1 +medperf training approve_association -t 1 -d 2 + +# lock experiment +medperf training lock -t 1 + +# # start aggregator +gnome-terminal -- bash -c "medperf aggregator start -a 1 -t 1; bash" + +sleep 5 + +# # start collaborator 1 +medperf auth login -e traincol1@example.com +gnome-terminal -- bash -c "medperf training run -d 1 -t 1; bash" + +sleep 5 + +# # start collaborator 2 +medperf auth login -e traincol2@example.com +medperf training run -d 2 -t 1 + + +############### eval starts + + +# submit reference model +medperf auth login -e benchmarkowner@example.com +medperf mlcube submit -n refmodel \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_other.yaml + +# submit metrics mlcube +medperf mlcube submit -n metrics \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_metrics.yaml \ + -p https://storage.googleapis.com/medperf-storage/testfl/parameters_metrics.yaml + +# submit benchmark metadata +medperf benchmark submit --name pathmnistbmk --description pathmnistbmk \ + --demo-url https://storage.googleapis.com/medperf-storage/testfl/data/sample.tar.gz \ + -p 1 -m 3 -e 4 + +# mark as approved +bash admin_benchmark_approval.sh + +# submit trained model +medperf auth login -e modelowner@example.com +medperf mlcube submit -n trained \ + -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_trained.yaml + +# participatemedperf benchmark submit +medperf mlcube associate -b 1 -m 5 -y + +# submit inference dataset +medperf auth login -e testcol@example.com +medperf dataset create -p 1 -d datasets/test -l datasets/test --name testdata --description testdata --location testdata +medperf dataset submit -d $(medperf dataset ls | grep test | tr -s " " | cut -d " " -f 1) -y + +# associate dataset +medperf dataset associate -b 1 -d 3 -y + +# approve associations +medperf auth login -e 
benchmarkowner@example.com +medperf association approve -b 1 -m 5 +medperf association approve -b 1 -d 3 + +# run inference +medperf auth login -e testcol@example.com +medperf benchmark run -b 1 -d 3 + +# submit result +medperf result submit -r b1m5d3 -y +medperf result submit -r b1m3d3 -y + + +# read results +medperf auth login -e benchmarkowner@example.com +medperf result view -b 1 + +############ test other stuff +medperf auth login -e modelowner@example.com +medperf training ls +medperf aggregator ls +medperf training view 1 +medperf aggregator view 1 +medperf training list_associations \ No newline at end of file diff --git a/server/testing_miccai_shortcut.sh b/server/testing_miccai_shortcut.sh new file mode 100644 index 000000000..a611bc7df --- /dev/null +++ b/server/testing_miccai_shortcut.sh @@ -0,0 +1,7 @@ +# register dataset +medperf auth login -e traincol2@example.com +medperf dataset create -p 1 -d ../../datasets_folder_final/col2 -l ../../datasets_folder_final/col1 --name col2 --description col2data --location col2location +medperf dataset submit -d 54ea1643f6006ead7e8517cd65fd5275f99abe7349895be25bd8485761cde088 -y + +# associate dataset +medperf training associate_dataset -t 1 -d 2 -y diff --git a/server/traindataset_association/__init__.py b/server/traindataset_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/traindataset_association/admin.py b/server/traindataset_association/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/traindataset_association/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. 
diff --git a/server/traindataset_association/apps.py b/server/traindataset_association/apps.py new file mode 100644 index 000000000..680686727 --- /dev/null +++ b/server/traindataset_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TraindatasetAssociationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'traindataset_association' diff --git a/server/traindataset_association/migrations/0001_initial.py b/server/traindataset_association/migrations/0001_initial.py new file mode 100644 index 000000000..23e4fc840 --- /dev/null +++ b/server/traindataset_association/migrations/0001_initial.py @@ -0,0 +1,36 @@ +# Generated by Django 3.2.20 on 2023-09-29 01:02 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='ExperimentDataset', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('certificate', models.TextField(blank=True)), + ('signing_request', models.TextField()), + ('metadata', models.JSONField(default=dict)), + ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 'APPROVED'), ('REJECTED', 'REJECTED')], default='PENDING', max_length=100)), + ('approved_at', models.DateTimeField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('modified_at', models.DateTimeField(auto_now=True)), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='dataset.dataset')), + ('initiated_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['created_at'], + }, + ), + ] diff --git 
a/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py b/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py new file mode 100644 index 000000000..807ed44ae --- /dev/null +++ b/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py @@ -0,0 +1,22 @@ +# Generated by Django 3.2.20 on 2023-09-29 01:02 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('traindataset_association', '0001_initial'), + ('training', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='experimentdataset', + name='training_exp', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='training.trainingexperiment'), + ), + ] diff --git a/server/traindataset_association/migrations/__init__.py b/server/traindataset_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/traindataset_association/models.py b/server/traindataset_association/models.py new file mode 100644 index 000000000..460e8a5db --- /dev/null +++ b/server/traindataset_association/models.py @@ -0,0 +1,29 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentDataset(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + certificate = models.TextField(blank=True) + signing_request = models.TextField() + dataset = models.ForeignKey("dataset.Dataset", on_delete=models.PROTECT) + training_exp = models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = 
models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/traindataset_association/permissions.py b/server/traindataset_association/permissions.py new file mode 100644 index 000000000..898122730 --- /dev/null +++ b/server/traindataset_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from dataset.models import Dataset + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsDatasetOwner(BasePermission): + def get_object(self, pk): + try: + return Dataset.objects.get(pk=pk) + except Dataset.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("dataset", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + dataset = self.get_object(pk) + if not dataset: + return False + if dataset.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_experiment = self.get_object(pk) + if not training_experiment: + return False + if training_experiment.owner.id == request.user.id: + return True + else: + return False diff --git a/server/traindataset_association/serializers.py b/server/traindataset_association/serializers.py new file mode 100644 index 000000000..ffc74414d --- /dev/null +++ b/server/traindataset_association/serializers.py @@ -0,0 +1,136 @@ +from rest_framework 
import serializers +from django.utils import timezone +from training.models import TrainingExperiment +from dataset.models import Dataset + +from .models import ExperimentDataset +from signing.interface import verify_dataset_csr, sign_csr + + +class ExperimentDatasetListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + read_only_fields = ["initiated_by", "approved_at", "certificate"] + fields = "__all__" + + def validate(self, data): + exp_id = self.context["request"].data.get("training_exp") + dataset = self.context["request"].data.get("dataset") + approval_status = self.context["request"].data.get("approval_status") + csr = self.context["request"].data.get("signing_request") + + training_exp = TrainingExperiment.objects.get(pk=exp_id) + training_exp_state = training_exp.state + + if training_exp_state != "DEVELOPMENT": + raise serializers.ValidationError( + "Dataset Association requests can be made only " + "on a development training experiment" + ) + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + dataset_object = Dataset.objects.get(pk=dataset) + dataset_state = dataset_object.state + if dataset_state != "OPERATION": + raise serializers.ValidationError( + "Association requests can be made only on an operational dataset" + ) + last_experiment_dataset = ( + ExperimentDataset.objects.filter( + training_exp__id=exp_id, dataset__id=dataset + ) + .order_by("-created_at") + .first() + ) + if not last_experiment_dataset: + if approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject association request only if there are prior requests" + ) + else: + if approval_status == "PENDING": + if last_experiment_dataset.approval_status != "REJECTED": + raise serializers.ValidationError( + "User can create a new request only if 
prior request is rejected" + ) + elif approval_status == "APPROVED": + raise serializers.ValidationError( + "User cannot create an approved association request" + ) + # approval_status == "REJECTED": + else: + if last_experiment_dataset.approval_status != "APPROVED": + raise serializers.ValidationError( + "User can reject request only if prior request is approved" + ) + + valid_csr, reason = verify_dataset_csr(csr, dataset_object, training_exp) + if not valid_csr: + raise serializers.ValidationError(reason) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + if ( + validated_data["dataset"].owner.id + == validated_data["training_exp"].owner.id + ): + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + csr = validated_data["signing_request"] + certificate = sign_csr(csr, validated_data["training_exp"]) + validated_data["certificate"] = certificate + return ExperimentDataset.objects.create(**validated_data) + + +class DatasetApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + read_only_fields = ["initiated_by", "approved_at", "certificate"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + "certificate", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No dataset association found") + last_approval_status = self.instance.approval_status + cur_approval_status = data["approval_status"] + if last_approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject only a pending request" + ) + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + if ( + last_approval_status != cur_approval_status + and cur_approval_status == "APPROVED" + ): + if 
current_user.id == initiated_user.id: + raise serializers.ValidationError( + "Same user cannot approve the association request" + ) + return data + + def update(self, instance, validated_data): + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + if instance.approval_status == "APPROVED": + csr = instance.signing_request + certificate = sign_csr(csr, self.instance.training_exp.id) + instance.certificate = certificate + instance.save() + return instance diff --git a/server/traindataset_association/utils.py b/server/traindataset_association/utils.py new file mode 100644 index 000000000..7143d351a --- /dev/null +++ b/server/traindataset_association/utils.py @@ -0,0 +1,14 @@ +from django.db.models import OuterRef, Subquery +from .models import ExperimentDataset + + +def latest_data_associations(training_exp_id): + experiment_datasets = ExperimentDataset.objects.filter( + training_exp__id=training_exp_id + ) + latest_assocs = ( + experiment_datasets.filter(dataset=OuterRef("dataset")) + .order_by("-created_at") + .values("id")[:1] + ) + return experiment_datasets.filter(id__in=Subquery(latest_assocs)) diff --git a/server/traindataset_association/views.py b/server/traindataset_association/views.py new file mode 100644 index 000000000..0496a036e --- /dev/null +++ b/server/traindataset_association/views.py @@ -0,0 +1,72 @@ +from .models import ExperimentDataset +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsDatasetOwner, IsExpOwner +from .serializers import ( + ExperimentDatasetListSerializer, + DatasetApprovalSerializer, +) + + +class ExperimentDatasetList(GenericAPIView): + permission_classes = [IsAdmin | IsDatasetOwner] + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def post(self, request, 
format=None): + """ + Associate a dataset to a training_exp + """ + serializer = ExperimentDatasetListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class DatasetApproval(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsDatasetOwner] + serializer_class = DatasetApprovalSerializer + queryset = "" + + def get_object(self, dataset_id, training_exp_id): + try: + return ExperimentDataset.objects.filter( + dataset__id=dataset_id, training_exp__id=training_exp_id + ) + except ExperimentDataset.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp dataset associations + """ + training_expdataset = self.get_object(pk, tid).order_by("-created_at").first() + serializer = DatasetApprovalSerializer(training_expdataset) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp dataset association + """ + training_expdataset = self.get_object(pk, tid).order_by("-created_at").first() + serializer = DatasetApprovalSerializer( + training_expdataset, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + Delete a training_exp dataset association + """ + training_expdataset = self.get_object(pk, tid) + training_expdataset.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/training/__init__.py b/server/training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/training/admin.py b/server/training/admin.py new 
file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/training/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/server/training/apps.py b/server/training/apps.py new file mode 100644 index 000000000..8051e6caf --- /dev/null +++ b/server/training/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TrainingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'training' diff --git a/server/training/migrations/0001_initial.py b/server/training/migrations/0001_initial.py new file mode 100644 index 000000000..c17bdb4a3 --- /dev/null +++ b/server/training/migrations/0001_initial.py @@ -0,0 +1,46 @@ +# Generated by Django 3.2.20 on 2023-09-29 01:02 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('mlcube', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='TrainingExperiment', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=20, unique=True)), + ('description', models.CharField(blank=True, max_length=100)), + ('docs_url', models.CharField(blank=True, max_length=100)), + ('demo_dataset_tarball_url', models.CharField(blank=True, max_length=256)), + ('demo_dataset_tarball_hash', models.CharField(max_length=100)), + ('demo_dataset_generated_uid', models.CharField(max_length=128)), + ('metadata', models.JSONField(blank=True, default=dict, null=True)), + ('state', models.CharField(choices=[('DEVELOPMENT', 'DEVELOPMENT'), ('OPERATION', 'OPERATION')], default='DEVELOPMENT', max_length=100)), + ('is_valid', models.BooleanField(default=True)), + ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 
'APPROVED'), ('REJECTED', 'REJECTED')], default='PENDING', max_length=100)), + ('private_key', models.CharField(blank=True, max_length=100)), + ('public_key', models.TextField(blank=True)), + ('user_metadata', models.JSONField(blank=True, default=dict, null=True)), + ('approved_at', models.DateTimeField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('modified_at', models.DateTimeField(auto_now=True)), + ('data_preparation_mlcube', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='training_exp', to='mlcube.mlcube')), + ('fl_mlcube', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='fl_mlcube', to='mlcube.mlcube')), + ('owner', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['modified_at'], + }, + ), + ] diff --git a/server/training/migrations/__init__.py b/server/training/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/training/models.py b/server/training/models.py new file mode 100644 index 000000000..d86651f53 --- /dev/null +++ b/server/training/models.py @@ -0,0 +1,57 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class TrainingExperiment(models.Model): + EXP_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + STATES = ( + ("DEVELOPMENT", "DEVELOPMENT"), + ("OPERATION", "OPERATION"), + ) + + name = models.CharField(max_length=20, unique=True) + description = models.CharField(max_length=100, blank=True) + docs_url = models.CharField(max_length=100, blank=True) + owner = models.ForeignKey(User, on_delete=models.PROTECT) + demo_dataset_tarball_url = models.CharField(max_length=256, blank=True) + demo_dataset_tarball_hash = models.CharField(max_length=100) + demo_dataset_generated_uid = models.CharField(max_length=128) + data_preparation_mlcube 
= models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="training_exp", + ) + fl_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="fl_mlcube", + ) + + metadata = models.JSONField(default=dict, blank=True, null=True) + # TODO: consider if we want to enable restarts and epochs/"fresh restarts" + state = models.CharField(choices=STATES, max_length=100, default="DEVELOPMENT") + is_valid = models.BooleanField(default=True) + approval_status = models.CharField( + choices=EXP_STATUS, max_length=100, default="PENDING" + ) + private_key = models.CharField(max_length=100, blank=True) + public_key = models.TextField(blank=True) + # TODO: ensure unique, but allow blank (how?) + # TODO: rethink if keys are always needed + + user_metadata = models.JSONField(default=dict, blank=True, null=True) + approved_at = models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return self.name + + class Meta: + ordering = ["modified_at"] diff --git a/server/training/permissions.py b/server/training/permissions.py new file mode 100644 index 000000000..98e59e048 --- /dev/null +++ b/server/training/permissions.py @@ -0,0 +1,27 @@ +from rest_framework.permissions import BasePermission +from .models import TrainingExperiment + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False diff --git 
a/server/training/serializers.py b/server/training/serializers.py new file mode 100644 index 000000000..42d96379c --- /dev/null +++ b/server/training/serializers.py @@ -0,0 +1,97 @@ +from rest_framework import serializers +from .models import TrainingExperiment +from signing.interface import generate_key_pair +from django.utils import timezone + + +class WriteTrainingExperimentSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingExperiment + exclude = ["private_key"] + read_only_fields = [ + "owner", + "private_key", + "public_key", + "approved_at", + "approval_status", + ] + + def validate(self, data): + owner = self.context["request"].user + pending_experiments = TrainingExperiment.objects.filter( + owner=owner, approval_status="PENDING" + ) + if len(pending_experiments) > 0: + raise serializers.ValidationError( + "User can own at most one pending experiment" + ) + return data + + def save(self, **kwargs): + super().save(**kwargs) + + # TODO: move key generation after admin approval? 
YES + # TODO: use atomic transaction + private_key_id, public_key = generate_key_pair(self.instance.id) + self.instance.private_key = private_key_id + self.instance.public_key = public_key + self.instance.save() + + return self.instance + + +class ReadTrainingExperimentSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingExperiment + exclude = ["private_key"] + + def update(self, instance, validated_data): + if ( + instance.approval_status != "PENDING" + and "approval_status" in validated_data + and validated_data["approval_status"] == "APPROVED" + ): + instance.approved_at = timezone.now() + for k, v in validated_data.items(): + setattr(instance, k, v) + instance.save() + return instance + + def validate(self, data): + if "approval_status" in data: + if ( + data["approval_status"] == "PENDING" + and self.instance.approval_status != "PENDING" + ): + pending_experiments = TrainingExperiment.objects.filter( + owner=self.instance.owner, approval_status="PENDING" + ) + if len(pending_experiments) > 0: + raise serializers.ValidationError( + "User can own at most one pending experiment" + ) + + editable_fields = [ + "is_valid", + "user_metadata", + "approval_status", + "demo_dataset_tarball_url", + ] + if self.instance.state == "DEVELOPMENT": + editable_fields.append("state") + + for k, v in data.items(): + if k not in editable_fields: + if v != getattr(self.instance, k): + raise serializers.ValidationError( + "User cannot update non editable fields" + ) + if ( + "state" in data + and data["state"] == "OPERATION" + and self.instance.state == "DEVELOPMENT" + ): + # TODO: check if there is an approved aggregator other wise raise + # and at least one approved dataset?? + pass + return data diff --git a/server/training/urls.py b/server/training/urls.py new file mode 100644 index 000000000..ff298550b --- /dev/null +++ b/server/training/urls.py @@ -0,0 +1,11 @@ +from django.urls import path +from . 
import views + +app_name = "training" + +urlpatterns = [ + path("", views.TrainingExperimentList.as_view()), + path("/", views.TrainingExperimentDetail.as_view()), + path("/datasets/", views.TrainingDatasetList.as_view()), + path("/aggregator/", views.GetAggregator.as_view()), +] diff --git a/server/training/views.py b/server/training/views.py new file mode 100644 index 000000000..72ebc6828 --- /dev/null +++ b/server/training/views.py @@ -0,0 +1,115 @@ +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .models import TrainingExperiment +from .serializers import ( + WriteTrainingExperimentSerializer, + ReadTrainingExperimentSerializer, +) +from .permissions import IsAdmin, IsExpOwner +from dataset.serializers import DatasetSerializer +from aggregator.serializers import AggregatorSerializer +from drf_spectacular.utils import extend_schema +from aggregator_association.utils import latest_agg_associations +from traindataset_association.utils import latest_data_associations + + +class TrainingExperimentList(GenericAPIView): + serializer_class = WriteTrainingExperimentSerializer + queryset = "" + + @extend_schema(operation_id="training_retrieve_all") + def get(self, request, format=None): + """ + List all training experiments + """ + training_exps = TrainingExperiment.objects.all() + training_exps = self.paginate_queryset(training_exps) + serializer = WriteTrainingExperimentSerializer(training_exps, many=True) + return self.get_paginated_response(serializer.data) + + def post(self, request, format=None): + """ + Create a new TrainingExperiment + """ + serializer = WriteTrainingExperimentSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(owner=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, 
status=status.HTTP_400_BAD_REQUEST) + + +class TrainingExperimentDetail(GenericAPIView): + serializer_class = ReadTrainingExperimentSerializer + queryset = "" + + def get_permissions(self): + if self.request.method == "PUT": + self.permission_classes = [IsAdmin | IsExpOwner] + return super(self.__class__, self).get_permissions() + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve a TrainingExperiment instance. + """ + training_exp = self.get_object(pk) + serializer = ReadTrainingExperimentSerializer(training_exp) + return Response(serializer.data) + + def put(self, request, pk, format=None): + """ + Update a TrainingExperiment instance. + """ + training_exp = self.get_object(pk) + serializer = ReadTrainingExperimentSerializer( + training_exp, data=request.data, partial=True + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class TrainingDatasetList(GenericAPIView): + serializer_class = DatasetSerializer + queryset = "" + + def get(self, request, pk, format=None): + """ + Retrieve datasets associated with a training_exp instance. + """ + experiment_datasets = latest_data_associations(pk) + experiment_datasets = experiment_datasets.filter(approval_status="APPROVED") + datasets = [exp_dset.dataset for exp_dset in experiment_datasets] + datasets = self.paginate_queryset(datasets) + serializer = DatasetSerializer(datasets, many=True) + return self.get_paginated_response(serializer.data) + + +class GetAggregator(GenericAPIView): + serializer_class = AggregatorSerializer + queryset = "" + + def get(self, request, pk, format=None): + """ + Retrieve aggregator associated with a training exp instance. 
+ """ + experiment_aggregators = latest_agg_associations(pk) + experiment_aggregators = experiment_aggregators.filter( + approval_status="APPROVED" + ) + aggregators = [exp_agg.aggregator for exp_agg in experiment_aggregators] + if aggregators: + serializer = AggregatorSerializer(aggregators[0]) + return Response(serializer.data) + return Response({}, status=status.HTTP_400_BAD_REQUEST) diff --git a/server/utils/urls.py b/server/utils/urls.py index 736505ae6..662a5e504 100644 --- a/server/utils/urls.py +++ b/server/utils/urls.py @@ -11,4 +11,11 @@ path("results/", views.ModelResultList.as_view()), path("datasets/associations/", views.DatasetAssociationList.as_view()), path("mlcubes/associations/", views.MlCubeAssociationList.as_view()), + path("training/", views.TrainingExperimentList.as_view()), + path("aggregators/", views.AggregatorList.as_view()), + path( + "datasets/training_associations/", + views.DatasetTrainingAssociationList.as_view(), + ), + path("aggregators/associations/", views.AggregatorAssociationList.as_view()), ] diff --git a/server/utils/views.py b/server/utils/views.py index 7b214e48d..e1b722196 100644 --- a/server/utils/views.py +++ b/server/utils/views.py @@ -19,6 +19,14 @@ from rest_framework.permissions import AllowAny from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import serializers +from training.models import TrainingExperiment +from training.serializers import ReadTrainingExperimentSerializer +from aggregator.models import Aggregator +from aggregator.serializers import AggregatorSerializer +from traindataset_association.models import ExperimentDataset +from traindataset_association.serializers import ExperimentDatasetListSerializer +from aggregator_association.models import ExperimentAggregator +from aggregator_association.serializers import ExperimentAggregatorListSerializer class User(GenericAPIView): @@ -54,6 +62,46 @@ def get(self, request, format=None): return 
self.get_paginated_response(serializer.data) +class TrainingExperimentList(GenericAPIView): + serializer_class = ReadTrainingExperimentSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.filter(owner__id=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all training_exps owned by the current user + """ + training_exps = self.get_object(request.user.id) + training_exps = self.paginate_queryset(training_exps) + serializer = ReadTrainingExperimentSerializer(training_exps, many=True) + return self.get_paginated_response(serializer.data) + + +class AggregatorList(GenericAPIView): + serializer_class = AggregatorSerializer + queryset = "" + + def get_object(self, pk): + try: + return Aggregator.objects.filter(owner__id=pk) + except Aggregator.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all aggregators owned by the current user + """ + aggregators = self.get_object(request.user.id) + aggregators = self.paginate_queryset(aggregators) + serializer = AggregatorSerializer(aggregators, many=True) + return self.get_paginated_response(serializer.data) + + class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" @@ -157,6 +205,50 @@ def get(self, request, format=None): serializer = BenchmarkModelListSerializer(benchmarkmodels, many=True) return self.get_paginated_response(serializer.data) +class DatasetTrainingAssociationList(GenericAPIView): + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def get_object(self, pk): + try: + # TODO: this retrieves everything (not just latest ones) + return ExperimentDataset.objects.filter( + Q(dataset__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentDataset.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all training dataset associations involving an asset of mine + """ + 
experiment_datasets = self.get_object(request.user.id) + experiment_datasets = self.paginate_queryset(experiment_datasets) + serializer = ExperimentDatasetListSerializer(experiment_datasets, many=True) + return self.get_paginated_response(serializer.data) + +class AggregatorAssociationList(GenericAPIView): + serializer_class = ExperimentAggregatorListSerializer + queryset = "" + + def get_object(self, pk): + try: + # TODO: this retrieves everything (not just latest ones) + return ExperimentAggregator.objects.filter( + Q(aggregator__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentAggregator.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all aggregator associations involving an asset of mine + """ + experiment_aggs = self.get_object(request.user.id) + experiment_aggs = self.paginate_queryset(experiment_aggs) + serializer = ExperimentAggregatorListSerializer(experiment_aggs, many=True) + return self.get_paginated_response(serializer.data) + class ServerAPIVersion(GenericAPIView): permission_classes = (AllowAny,) From 8f52c499ac38068c2720e734ecf634d5aa824ff5 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 29 Feb 2024 00:03:21 +0100 Subject: [PATCH 002/242] adapt mock_tokens folder --- mock_tokens/generate_tokens.py | 7 ++++--- mock_tokens/tokens.json | 18 ++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mock_tokens/generate_tokens.py b/mock_tokens/generate_tokens.py index 9579b0711..7d964982a 100644 --- a/mock_tokens/generate_tokens.py +++ b/mock_tokens/generate_tokens.py @@ -25,8 +25,9 @@ def token_payload(user): users = [ "testadmin", - "benchmarkowner", - "modelowner", + "testbo", + "testmo", + "testdo", "aggowner", "traincol1", "traincol2", @@ -42,4 +43,4 @@ def token_payload(user): token_payload(user), private_key, algorithm="RS256" ) -json.dump(tokens, open("tokens2.json", "w")) +json.dump(tokens, open("tokens.json", "w")) diff --git a/mock_tokens/tokens.json 
b/mock_tokens/tokens.json index 532c3e600..096b90a06 100644 --- a/mock_tokens/tokens.json +++ b/mock_tokens/tokens.json @@ -1,12 +1,10 @@ { - "testadmin@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.oKS-F8L6Vv70vKklpX2kBiBkkKCh9D-RtxdVzSAiNk9MTDIJ-_7wNu0yBiEnX38IOubnifkh8v5OyAJ85dU5Au1LsekI0YyTI0WLQXKbywP89vfYZlfEIACvPUWJhRbHMJOGn-WVrPuEbGMDuDw677xOm5T04Hol9Qg4rAsNjYt05SnVwM4ico2CH9AR0LSrsdC_QCpLVvym9ewE1CrstmalPWWM3SeBves4qGSlwl1oTXgoOUXgK9DwxLHB-r66XZrcNwXZRuBYSTeqQvDnGP8TG4bXL1gQbvkh2tDbtj-DEyfUXxPxN_GVnlqk6I8BS-A9IoWiKdf0rYatelHd7aWbBgdCg6fxJ4HL4vxChqi3-X6dH2O4vUGkTCR0Td5NDhhe4gfj8WxXD293i0Glu1xOO8DVnu4j6GDfK8WfXtUgHwc4FJHb6iXDJqhAnP2jy_LPSEfjnItKQNvRyu8W3D6LxcumCLO6IvhFkOFDzcjkkyEIdLEiUzDkNeiM5eCqmuaoYW89ARy_GTMKMaLKRWyeGzeNP6celYJtoVzaUOeQvlD7mTFuRupiVC5PxuVFysdTUqU97MSRGSrjqxpghKcEHzVH_mJypWM_Psv4v_6zUhzBV7VKPfGTyareK7Wl6aWPKrGYVm23aAWAjcMgd6otQpjKY6SIkLCIRASuZlE", - "benchmarkowner@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoiYmVuY2htYXJrb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImJlbmNobWFya293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNjk2NjA5NzcwLCJleHAiOjExNjk2NjA5NzcwfQ.OPgLiySauDxTtjjAEvcOoklpquVOsu95XBECxDvn1v1h1NW__h7ob21bytCJic34wfF_b2Zy9FOXF8kTKUwmVqwA3lYn9i35R9LegGPyEqNMOQ89ou7xZFJvxUJYm92fz9R0oIh99swyCuAqaze-B5I8B8lH8wUBPlsaV6-EV-O2VOPmACQamLJzmuLhKg5P9cTP5dngEytHK6AFzIiuDYk_7JKfy-DCsxPRnpQV8Ct_dZgwV3RxTtGaFKktttrgTOo7DRPjXw5q-BzLX8RQYL8W5Y6taQ4qPVz5Q32EYKGzvCJtNj-gEFE5p30kfv1DMKyHwe9WnYgGG1EMREnyUWeHNQWG_SUfk07sE_RJpQ5FgvQFT0PGP0vKX16YUy5CqXQflRN45lv9qBFUEii1v4ORDJbYtrgaTEQeOLaL2l_7ucnNQsgeKhkiXLUdWedueUv1Zjd_Y5yL0isG6pUQeDJXtt9FugZMkzrWX7yqrRmqDy4oX0DMobCuLXy858MWgjwd1xEq0vN174OFjW1d2fJ9unNlx_A9ZRAPIV9L3ZtZ91yqlKfkym_8WTzcUyHzeeirQ9Gq6_vVYHIBcALJ2_EQWSWqbZXJtTbUMxT4IIGWZtXaDOMr23HsxtcuMQkjFVFJ0BW0SNpUByc1mIis99RtYHbbSPzpuWHhEKhhfm0", - "modelowner@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoibW9kZWxvd25lckBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoibW9kZWxvd25lciIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTY5NjYwOTc3MCwiZXhwIjoxMTY5NjYwOTc3MH0.YKGELPYlzHjT1M5C83xo6wLSMQ1Cj5O9AixoUUikc4PYXbvaaz6kZ92hc1vD9Oy9LTp85JKMkiwHkkmenkkYrEvkZzc4WTYaKOqGbn6fKpErBnSgszhbPnh6oI3hErSYI2hAI42v0w2H39vY4Kj6dZxo-1grWZG4D_o1xcc9OM4BAr9cD2GEQVrtURTbF2j46gz4uZZHWZZEyRCJDFfpKq9X0EdhR8muFvKQZ99Jfp-omM9vHZ7E4Bj9W4K15xodhKVzDwlFVBK-oXkvPo08-vMlWMQvXxB3dKBkPQMjYUOAstcdi4D6mEv9MfDKwxXIY_dKsxINkReU-6CdSRDO1mmc2SJ3k362Bd8r_Aq3T9P57VvsxyUxdD8RzOfuJk2letHSbhkJ4XuJeARDaF64oygk3-jLNhuZ4LEDMG00BqstIyEH8WlhKGhDU6AK0GrHytSp0NeeMyAXPGJg-OJn29eV1SW2N4UzOQ2Fnjj9klr07zM1U_vc68P4pKIwvnkoWVeNBjrU5sVugqziVX_BnbNHaOYfWbNlIZ4ngkjapr9Xr8WJd5yC2bcp90hgy9cbaLbjEKmF4mxnzE8IEkjXMzGDAb9oWpYSUnU0U3dxaKgdtDB8-mmjlcYpojcN7Iby2GLv8DSTuW9iojxcKw3YLoKONCfEVKI0ssc_R4zb9dI", - "aggowner@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoiYWdnb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImFnZ293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNjk2NjA5NzcwLCJleHAiOjExNjk2NjA5NzcwfQ.E2UH220kMwuow7bfKt13DCMAzY6i0YO7AciNqGXl9ApWT6kJiYC_swgvzkfDd509YEEcEeiVS8Ik6xfWCblqXVPkfdYLm1MmzHbnbDODIfeTd9uPAmIrBkDYY3vnQtpcg7NJS24VxO-2xa35YY6A8FsWGD3QqnJd9tuken5RQ4OAZTPaqvsPiYDLXfYwieM3bVjM7o7GsmykOiPN8E0-4qH78COWBJU9izymluJJhEetW8CdCxmJ7PomXQXvWrCoQjCf3J_8i28TABIianjHytdFSpXxqE14IVwn-OU7qj6V6zmjr5gfSmocNVcb3kOrwb_QJv3Cqww7tUJwF0q5_EncCtVB0XtZkDFRCBzQynI8-VR-z4eZ6SpK4RdYsBYHNydGh88RfKvcszCceAh7MGhh1XEjaKgM_IQAjnrXlMv2dRmvuSxooXXVhoZ1lx6tq9Sfvf4FvuOJCAVicPfaiMpoZW0UblEQn6zhugu3hiu6huDaYd5Opx9ZyRaq_fH_SNwnd9GvioqNNgOyatrCgqFSJdMQlMfHv_lXl0tgR60wFpWuTAZSz68bMvivwACunN3C8XwER1i-SHBO6OwoKpXMNW1bOaWPo9niggY1LdQuBeLauKG35Ee7rQZ8sVjKadu58It6GHBEQG_jLJN4puFmSHH6eMhwGQKjaVuOKS8", - "traincol1@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidHJhaW5jb2wxQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDEiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.lz-yInC8p4ENvIfWghre7B9UuAZKP5yQaUc6V2-0QsgP5fq3HnUpd5zl9DaWLDi0WX7OOk9CbY5aOFKgOZ1g4ulSh0QX382R3QpUMmKbvaZrVVKaaN1AaNmjv_89bKNfqOboz4z3hj_9S5F2I2aNy-LYYKzxa9cbMyHCieaA9-KZdZRVV-tK1co-nxnVU0QlKAf6TQwVEaBfBPWapFwnyqSs9v7M6WixKmyr0zWjHupKObmncpVkKmIh3HjTXEtdxeC4F4-V7xPJBHoK_hbTcnzySzQYmZWzgHiTWK_lM-U9Y3ugcPDRoVnIoL1_tJdFrBtsqgTqYvgGAW_1gwd94eWvwCUzpNBjr910byRPcFHdlZX11vnmhESVdZV62wQeOuwacR07FSWGxYHgAM_UWhySOcoB4qtza14tTb8YrlIZVUMYxoQYlB02RY5ZXvV00RqKJWJWKdQYCWgheeuZulTRC1Y4V9eu8GRZsfnk4omWVmkkcmx_Q3JHuEiGFfZmTZTF9p0m2tVZBWE4ML6EnX7ndJ8scZnsqZGkI8WsEabaPmcjg6elRGiqIMsOj7oigEiYqKnxiWQ3bAlXUbXAMYX_kAb9y3PVoiMsXpT9cFkTbOkFTU0_aYlcUOsIAZLhu8dMw29fJ6r8ZMgblOwo2jsgqTrAUqeBM85YXSz6NAM", - "traincol2@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidHJhaW5jb2wyQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDIiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTY2MDk3NzAsImV4cCI6MTE2OTY2MDk3NzB9.PRs-r2Iyye63fWyplOhGCQhKNHRMCup2upgpIWmh-xXAV1cBT6Wq9ANYq7oUmq03G8Aw53ZX0H8wV67nIhkD_dG8gEV_Hp_36rv8EmPUdlBYLQP0rV059zZP7s594sXtZ1G1hFZvs18XY_GNJuyOpZ-GdWH2nPjAvyGfU8JBYVDjf79HgJbrbfDLBvlRmrrCCA40bO6ScNrTXPnsBuefEuLEGqWsBiVKU8hOoBnPj8NAanxQBjgpph7kkPU7kmaxHn9rJEp3-S_8Ozi3J635roOKgumysDDrwcDt7oPJLrL6SVhHWzmMzBxN-ozdXAI9sJw1H4_bw6CG3MoahNJxwr8kHdFh8GqA58K_aVOAZNeSa_EUJGjquVkidOKyKcfcammHgw6cUmMk2Y1GepqSr4-KRjLrewkC3jxdnCeWgqPoiUrpCc8OcRGXMhyiqvYJUFdnfqEOMvXN5KCo5KeGew20h3zLiJaacDNQDufjLycq7x2DbS5UmtWXgAbGQWKllaVIwrIOAJ-Ev8HKXEXsYcPUqrTel481ptzChZ0Co9HokXGhHf9R0s04zwc3c-jQUqmcMa3nYSuEoInQdgjP9xgP2jXTyfwj-Pf5BZ8Tisurt_PWugcheKWXBL4txVoulfsN1CEDR7sXeWIbSVhRXX15EuASETUMMEzlI5WaXM0", - "testcol@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGNvbEBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGNvbCIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTY5NjYwOTc3MCwiZXhwIjoxMTY5NjYwOTc3MH0.XEj_ft8hxKm1dphiBhivx-SWG2BP39k1cMkArDHviKtp1vlROUUN0GgE_pkNfyn_6Rx9JUCp6LyDbCzCSC0ZlZ47qnuCy-rfVPUNRhKE_UYzFHUg8joaCJ6_3o7gzAoJLoHl6oc5bL3LcQIPkIqCFyoBp3JUaT0Arrv_tqSw6Y8WyQVxfkmjcnRIKxpCvUPW0SfVq605HigyULvfWyrJSepAn_Mw8bfDGEgCE3DfkODI2t0qpGuM-m-0neA99jk13VJdtEYVU2c_rbiGc0W8fBvkwytlqCit2JVDT6SlLgkJHr5WJQkP4Jx5AJ21bdTcWVMJu-xInO8yXSvLG_u0rRPl5oNImYXlO06c0PYCXxTuqqbljm_Knj48YcoBWrgWuxB-geH_LCXPxT7O1pQoosodBPwNHwlMt3C84rcUXeOxPTErcDputY-UNnuKxK0XNGMpsXnJZiQRh3I_4v5FtHE01DY8w9XjrMxLjBnUuISkz-Ct195wG-Od4v7Sw_7ikey0fn6f0EPo4ETmM1q5Oa_1RHV64GWG5NtaDezZbqDYQYCgFxm1coUIbPFwSWv7gT7w15Mj1kYO-RoLmt7DtHmTLu9ilebNhfUV34ZABnt96RFo_bSJY_t2eeWYSwRlnkCDbktmtBKlG5o8LKQVJq6Uki-90H2-jyIG1dp5BRg", - "testbo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.c5dnk-e8ys8MpaRzBkFe_RQTjGLKQlGSkJVNBhiKbzJV6565iHQ1HvFp88q-6ldGIwVWw9ZyY6QH2EWbPYfLF69-KnNcxUbDOJe9jBX0UAfaUsrcsaVoxLJPCojnjqKoIKgu_NM5PlEvsn4ojYA8Q-DxJ7r9RexnhxG0TxU_CwjhYuzV9RQsE3phWbHmcFv1-OIGWr72q1p8QwfaQp42K4iyUC7u4Buk6we9V4NJIJPPjjadkmdsYnm3pJxYVhKZx1pTkGnHTt6YEilQM9Iwgw_1mA5o1AYkwTm-_9lMxbSiwMzdxZmM1S_L-XoVlrWmyxeu2-BdLMIt49LS3fSXXDkcGyLboxLlg8v5rCMkhpUvDvw-dwpVUi1Y6QIVIqQocDJ3Bj1K-5SbGISc2wU7Aa9GNXe7GcEFn9DoeCthy3aLfucc3l8usZagopkAjVSGmClSJNw90VoWA7kER31E7Ehas5mnlJedGeyNV1wm5r2sMJfUsnoVqaCMHNFE4SQJaMVkGQOS4gn-G_8_WMLVrvfT1f_dg-dvISWrvhZ7-Uow-gi9-px2oY3Ehk2b82CPRMf3O20HOppLTg7ETqm77wQq9Elqrn2KoAYgc5Rlr-JZEq76NI37AgYRDIRistLI9_-UzJmY-YYbep2PsiWBx8WrUqm8uDIKDoGhfjBElQo", - "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.FOKlc3ZKH193-cINs9fxIPitsPeE2RZYl_H96xWaKIGOYOec55hR1WgybXD47czGUuTLS8NQGnFcTf-CGABWhFs7f3oTWKxpqo98sOGIhXxASgktfKmqL8tZCv-a4LCUoGFfeekWCvSo_fezkFnxZT1090JpPYpwcbjbTPY4u0JSU4XbaQh4y7N0lJXlmaZuAb13ncur7uY4A6Onl3u49m3tRjbg36r8fyzUwKppMyo5JdLS4h8099ZW7B6v32Xsr4TEz-UOA88YF_mmgK1P9BaFL7aQuARqsqvIQgdmH4JYxnoyng-UiroANy_HeeMJjhDtkuAVUsF6UOG_excgC9jA53GKNCSxxqERD-mxvZ5juWyL86fTSwuaBALWVv5HFUjHbxsA-2JzKt-ZrSD93Q-DqTi36f1WGfwphypc-2W201d4l1EJSg2b4FAqqt2dDDV4deGg3FCkR4vHXdA5aWIihZWmtSLsCylkg9bJiR0iZp6w-6I4R-8KDTG10UgFLPQDS95yLhClcjUjFnmI-xs5rLJvfuIbQU58yoOlBNwVPEm-G7vfTpw7kl_x3qjV_mOw2QSJE0_iQ_gi_YZ1Lzh4ueEgBFYF7VPrLV4x2djJ-Yd1Qb--YQaBRTTFpZ1A8EqUQ-BCn0K9FoWgT2ChZj-B_Oks6z5Z2sqXyauYFU8", - "testdo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.PPfm_uOpz0SuAUEVKE-630mSDuboRUUliSoL5CiE2gUmzxHW5iyJ-37pOfuO5e6y0A4YMnh_xTv3yxgnG5OOB_6ZLD_KF6e5BRJWkZhqKixGLDpQReUv9PadR02eMgtaQBKNuy9Ey8EV8mtyeIUH1JDoIwO4Lt17XHzvYwM7JdCH3TguxPYvCMORoLwKFfhPUlBeMDnjQBnOiw7gv46CdoXHAUqj7k87gp04opgSaeA1tsJIzakZiwXLK4CJvTo_jIhd_w4RNXYHiacI5AtoI8zJKzIdmptkJKcJv2yWq8nFpBlpAr09X-c4haBB4xGeX4743yOBQ6jMdW2f6mqOPTypbEn94_tbn-HaqbZJoclSwByDX0AN3j5KZ0-W-zR3CkPazkCqIWNLO9dte1fz2iocmsAmYdxmvcCN-uAMFI1tRaYHHTd3lyv-GOOllhs-Pc-H5fBfuN-H-l1LejdtdIN9qAxx5BgGVla_ac1mdHEqljO4AuppK4dMcSzw3BUHU4R0uqZ81-a-RoahG8SiPvihjl8foWCaB86qCBJDGItJno2Zdnw4Qpk-EeVSCnQAz0PzVhw9hhhNlAeO8GvKL1-OR1aiXn126empXYM7-wDjs8aAEY4oCo5987GU-pbdVwE8As2FCFI6vbW2B3JBmKideidFpeT92JWqeXv8Q_M" + "testadmin@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.mp0Q4h98q-NcFV6ub4qywiOy5bFAE3pSpH70UZ86MdL0N_U_xpkWZ3VmrF3Asf5Z3LtvsQMcT5S4C9gAdxFhvKXfSXYim2dQEpLMLpcT02aOf3IE1bTzm7G2WXiRvAXPQRlqt1KFfwWMuT1P4bZvB53Si3_slBX6BVPBpll5ZaYavQ8k1gAg-28zcMarDKJBnDbDy7jCMSHeKUQ-LrCV7LlCN1sdbXjhHhx2B13FdaNXHxPZHSPl_grtMZlZNiq6xzRhnVy3nUBCSuGxfzjDHmncF15x23eX0NDlZu1HS1uOsRT_3sonrcRE4r4XKFCjvyEAORynHez4v5BMXXqyf4MRYxYQ7py5OBfUjwABd52N21iIvikQTRF5yRWfd6QxkFMQ4oQ1IQUd8FHnzYX0gbHyxG03HhGdR12Z_EM56pzixgcph4bTIkr2_w0qGRfAvl6gmyFl5KEUhjeO7vZC24dJTej0yr729dcB037VK8hhCYBIKDl0OvVvSFEPXNvxlkhM86ae4N3UEekXhb4ESF-jx3Akxbx4GS3gTENmoHST9PCBQvPZdJqwXcH8lUbHlW4ugQc0qpqYRNy4z4_aRR2Q3OWkpJzkAhmcdYYiQU9ACkbRmggSDIYq_P4HhQFG15QfTcNLJwBwia129LZWzwWa-en9WcIIceQ4g-hmBq4", + "testbo@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.cF0KGfhTNv2CHYJw9QCRnakMwqxli9piLoeliUh9KBvHUTrLFOz2S6ZtvbiaTlhi8sx_TPmDDzNl09Vw8TN1mzXUorvhNj5Al0_nAGjk10XnTaRSwj2ZkA3_vDggXI6zu20_1DI3JR9gcOYq7He_uPqmzRfnFnCw6MSpM_FlvxLkVJzaqIe8Jh1XslXIa3Jr7uWBv_Gigw6DMen4Y7jsXuFPRiQoF86zEziv-2l7VjTx7EovS3loGQjDCJ1LwH053QdlmgmMa_3P6viKpCuvO2jablbDCMAxtg6JJmgbPYXQ8kmrVdRBEtybUY-4o46OFjs9vlkiK2JjUffvoYWQViwZKcgMQ6SAWQth_3HvLVlCghP6Qg-0eaJYR91gkbJ5VMXRccugzPGzDQ0o9yOSenIhj2GCKamn0YHMcFV8t7TtNdy9BuMjLiTSDz_v3nVU_rAdgqvwRXbkTuixdUYqx9h761wqpw_VaQnaVkaCJKhPOC2V8k8gygZsfNZrsZ4j2u5Dd2JmZgLzLCMp1XlXQxAboaAGRDgb6qZrojRaYYMQ59MAO6pdJ3IShEZMcEMKZdsbNVhZcghBHOFAfEuHgo2DC4itTB-VZVZc2W-j1gP6p3PqG3k5mHJ6_PjRDenqTmatDJJYDjxgHYC2nSQkOPD1CSu0XPRMvgReajYgRwM", + "testmo@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.h4ElJzfmsK_LWOYYhMPyLvg0VGtG7IM1JKA1ajZGSPnsKc_1uEtJ87k8Tgg5mIkBA5Fggv0Qyrmc95awAbpaqxCaeCJQAawRx98UCTQ6emmW4w9GE-eryLMVv5EpYyugKypUoQ6bpSg-kTEQ4K_f4i_pHBGAwHPjXR1brNy0krbrL56Q6P2SEJaoynPXMT-qVCIuND8DXb5XBNVUtgXP-VINuG7TOdJyZ50Id_8JcX4-8CO2qSHs6jXy6sgHX6FcCm38E-1xEhNQWT7OX8IX3Tlf3idSWKoV_ZQQX1z5GXGFbOgLxnr8aY2yuhmnTHNIWeMoIHIIw9bODagKMBAC-O-uvyr_ejcMRNahcoukQD2aHFjykUNKjkCSKGMi8dAWhVr3rl0YfFR5CNzo_vA3s2TICHYZvD3_jik07OXnIK04otAHH6mfBypHuA4ThRAGU_twW-FLjEDYfh5Rxm9DuIcsPM7fGr28pbZSIw8nnPYf41fVzBIzAhvcD3mpjqkU6C4vXq-LEd6mBNIGMOsYUm_83Ae1gpBgexu4m4O5DUroILAjclEIG-meF5PHmR3uGOt-oiJ39jXUo73HdU5bGMPzSzC5TNxKZMIdY3yCReNUCKvjr4R_YJbfTKn7aJS2prs_1kQ_0akcICiPmrXAng5nQV32MiqCiq2xu9kxk0Y", + "testdo@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.heenig4R9kb4aIPuQkbmjVUwHUQPL7yMwZi9EZ3YMkpVWKwDb130-BX0ovYpxzkfDZoIlh_fUYIzfEPrqyUt2HvVn93JgJJp8le9JndNREjoq-t01RGyh5ieJBmmTxXPJ7X17HdCcAGnB0DNBFUNk72nW0qCmgCHAypey1QPi8iiINTuKYRHCy9yqC0hyQWkrUzue8LxGxPCD3AdKn7P19yQMTPDF4TRC-gzDSm6AurbIEJjnyGe2YIhI_w-m6qnHg5qlpw6CkaEa-yosD6tmCuxSRZ7-nQna9S2FdQRHl9344r5dzaN4tve8k3ZKsdtz_tfNuwAZAmTUMpExsNkmkDyTlQ6XZ_9u1PD0nEaeusZR9PBPEwEB3JQvC_lOMg2rB-gPlB1PPxNCS8M9WivqFelwsK_ddYhymQrcB6ZqHQe2ITiQKX7ONtqScj3JoNuGvt9nmO_JMC3GJlJSPuEm7ZbuIProTrE7aZpEPztYiyOHRQzGA-qsLM17so-QcRB6K2Z-XL3AndWsAfd1FaRxTtTVR061yTb1rxW35ChoHz4T-ImSXk04wY2SuW6oe1VJdlE494UM1WdPUURjXtM-bOGtZbnolSTLE-3DyUcrYB9I5GCUEDMkVOERXX7z5nsU8TUohUxwe57lreCVWABwLZZBOmf1gSEPejX3HuKMfA", + "aggowner@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoiYWdnb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImFnZ293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNzA5MTYxMzM2LCJleHAiOjExNzA5MTYxMzM2fQ.K_OcpQ5gxnJlg-G_-qt_YFITTnR2NqBwcSsRxK-XmF7x8nf3eiZAJ01pDeGRutLMyUfMHbDaGdIds0L30z86Bb2QAOnrh3Zfq25xQpyxg3-hXiRGeMURC5CfjTuibAsWyi1514uC9JpRgWrnrzkI8m_AaiElRW2rSgp_QZ08tDvc9ktSIJzGXACMiWsamWpTSvf4mu0Snp_Xe1K0yPZrUBczeNaOk8Qus-OhfYXHuKd4kiuIynUgaH9SiylQj8-wzb8zi6iW3x1qsfY_qexo1EHc5b-58-NV-0VY0PCmWIIfPQlxustn2h4xr8j80xP-1LnSxvG_0IOXeuiW88IVXy1KzutwiMuRN3qGusoLspRVWqQh5aVLGsRN0VUzJAUpPmUSCjp6VacFh3oRWK2vdM4R8ZcGSwIrlup_5mCq7_hGxQXTu1gZ1e9hW3LYYnuCTFrPPCzqqGSeXkHrb37oN85yW-JdL2vBoGnhNSxZIMAgMQK9Cq4BR1ZdDZpKu4MXkzC_L4ZW3vVlOpgmKJ9oYDrJtTwaPuVZ3ratIlZPNt7Dp-4z0MTATUcvyjxYSxq1qFH-SnA1pnmA6OVEIKgoCn9rjEfg2v_y6ad2fONfg4czVGJ46HFcYmJIZ84Os9ME6yAlciqNR89Qix9NNRvCbfd39Y5K40K_25l01S3JcgA", + "traincol1@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidHJhaW5jb2wxQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDEiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.p6T5fX5xLGZnxyTW1258CpjCFJ_IA3PyR-h7aR0uxQPM1SeVDECfEWU85YGsnKlxvZJItsmNqAgvQE-cCUnC5DOq5Fo6XNEofbGi_LCDOgPxZ0ywNkrQk_H1S3VZqt4ukU2smfh_0biNI3qvWkEQL__VnStRPm9UGFP9ZoKPqbKxL81oM16FaWD7Ahy9ZofgD4v7fAvp8Yvmmw4P6LjIqB0GG7H2t1Tg8qwgJV_i0zZ5yddVUcn08EZr1H4xlkk7FhpFLCyMSQ9NKJvT_DkHXwuzmsxn3DLOr5lfQpJLM7DfJekmFHJwqCtbC1YsDQ5RZCcbjMdpXQyE2WWs6zraGelL1DcRi9chuZSCQD4y2yfhm8eCW7rrrmvUYVP0kIwt0Bj9TvjzBjLmFpMgY23RPV-iI6bcqcU7p7k3-IK6WO291O4iDM_udY5lHqm2wus-57pVBhSRQ4TOvrcwD6K9uX64j9zbGDplhOkhYGp2M9WsWckWazYQt5tvSVj_63cTAjDewQvtv2UldriGCDTIQzGJaHnOqwSK3iXDV2KwGT_m-LEdvGTCpusi9c5J5mFVk8QHoFZbLNKW9Yo6KfxVZbbPf0p2FPjnLaSjcWFPkQOLrQLC8N8NhBR4i5lRSVmc7i8uitwyOYJT52wEUrnEZlgx3uh9Bv0ZZaVDCzK6HDk", + "traincol2@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidHJhaW5jb2wyQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDIiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.q6atpv6-DYSt0-F3a2Pb5OWKF6GjalX1eAASmNv-hn8vQtPGCjYcJ2EFC9FrR2Oc8loiFSPmu4POkLMxdXzsX436_chd8K28xssdrqodexvDrHr5W-3NK4LhDsNZoeA3D9Sg3zIP80q94YNXmCKNmZT4wuRmcH5pdAQoj5DantfKe-Ctb0tZb1p4Qt4XErL6BkY1X7Qw9UsKVTM3T9Aeb3unTV_m6k-jzMHpxqA4VrUS9KmfZS_0bpv4Po-X16umvNTiAxUchXTPHrn2uNhFz-4xxwFoGxGBxnWJXpBNe4JaJK-_yKMRI6qBa8A-t9wXlj7Fa-ygDYglenFkTItmIUZG27WJEglU6Ue4-AjduBwAMAvXNzEZUhKsdEns4yPV9DiGxlJfKqBNHQacx89SweRGqdm5qxPThFxtZw3zNbX1d9-THYVJnAbm8VuvklAI-4PxPghJP-ge2JUskutpuRDuGh-mUSKARc4c1OZdwxurycYkYgJuJ_a5NJbcq6s34TuF4acRh_hoNtNqA_CmxvvG8QLYLtTpO1nLFyMNxNEHVVTeKi5PkdNq2IEGwTR5M1nW-OSF4nwS3Hh9S_X7PvdTaEZTPFcjjvnBwghykLsi_j07xl2J2Ys9QYs_LANtV6dldnPBKKLg5FgF6rgly-oG5p0KWpq1lU4w7Jsq6Fc", + "testcol@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGNvbEBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGNvbCIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcwOTE2MTMzNiwiZXhwIjoxMTcwOTE2MTMzNn0.RiHDW8kT60yzFgIyoGt3QS6ok6VAzWLEYFe_Tr_qZV-IkWJULic9fAYkajx8fg4a1jLFRaeClGs6M3p0pj4izixZnCZ7v6OZFOuk-qeymtgSgRoo5gU7fzcSCzW34AHIC_5x2FlXY7Kp34r3pBzsWwF2KiNufeywdDP8bxWgz7_hwOf5Ne6BnKj1UtlIWd_u6wIHQ-OHX0bWQM_fqQCBqaP-qxkOCsmwNHqXf3JLs-fsi9BzPBORlBIl6jluU9gCsLBlRZ9CUv5zrTiLZUREVFmKy8C-eQE6TEPKbUxnWvOggYSnXPf8LsI4hLXxki8RdSUdDcG7Qe4bPUs1EAA1TNhHwlphGXnbPWSK6VAFFfsYDHwNZUEH3yYLXvRQcJGmcJOaIvuXOJiQ6kUYkijlNVrqLVFd7xotjCoavftdrj2uizBtcWEZTD1Sd1gM0o7Ok-B2BYzJXfPx6lNZu6k1oXcdTulkrkmFEKPvYG1WjxAu7FxLxoWLwQBUFeG2gtyAWUeh9gJ-KMvm0BOgXalSz0GQY7F5IDGhucKHxGA3k1RBCKEHz2KTZ4x1vle5rYR6jZPPb8iTikz6KaRTBF7WZTGd3XgDXNx-K2hdH3A3CjBYQO8aEQgPnwpKEn9I5VoVhT_QCri_OP4S3N82OpcuciAOoXu5y6BAS816Un2gXho" } \ No newline at end of file From 23301fdc4e9d594fb1fff81497b923f93f41275a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 29 Feb 2024 03:22:06 +0100 Subject: [PATCH 003/242] fix passing port and collaborator env variable --- cli/medperf/commands/aggregator/run.py | 9 +-------- cli/medperf/commands/training/run.py | 8 ++------ cli/medperf/entities/cube.py | 26 +++++++++++++++++++++++--- examples/fl/project/Dockerfile-GPU | 3 +++ 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index d81b88ce5..6a54e067e 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -85,13 +85,6 @@ def __get_cube(self, uid: int, name: str) -> Cube: def run_experiment(self): task = "start_aggregator" - port = self.aggregator.port - # TODO: this overwrites existing cpu and gpu args - string_params = { - "-Pdocker.cpu_args": f"-p {port}:{port}", - "-Pdocker.gpu_args": f"-p {port}:{port}", - } - # just for now create some 
output folders (TODO) out_logs = os.path.join(self.training_exp.path, "logs") out_weights = os.path.join(self.training_exp.path, "weights") @@ -108,4 +101,4 @@ def run_experiment(self): } self.ui.text = "Running Aggregator" - self.cube.run(task=task, string_params=string_params, **params) + self.cube.run(task=task, port=self.aggregator.port, **params) diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py index 31c030189..dbaed964e 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -89,11 +89,7 @@ def __get_cube(self, uid: int, name: str) -> Cube: def run_experiment(self): task = "train" dataset_cn = get_dataset_common_name("", self.dataset.id, self.training_exp.id) - # TODO: this overwrites existing env args - # TODO: CUDA_VISIBLE_DEVICES should be in dockerfile maybe - string_params = { - "-Pdocker.env_args": f'-e COLLABORATOR_CN={dataset_cn} -e CUDA_VISIBLE_DEVICES="0"', - } + env_dict = {"COLLABORATOR_CN": dataset_cn} # just for now create some output folders (TODO) out_logs = os.path.join(self.training_exp.path, "data_logs") @@ -108,4 +104,4 @@ def run_experiment(self): "output_logs": out_logs, } self.ui.text = "Training" - self.cube.run(task=task, string_params=string_params, **params) + self.cube.run(task=task, env_dict=env_dict, **params) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index b748e17f4..3e60185e7 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -302,6 +302,8 @@ def run( string_params: Dict[str, str] = {}, timeout: int = None, read_protected_input: bool = True, + port=None, + env_dict: dict = {}, **kwargs, ): """Executes a given task on the cube instance @@ -314,6 +316,7 @@ def run( read_protected_input (bool, optional): Wether to disable write permissions on input volumes. Defaults to True. kwargs (dict): additional arguments that are passed directly to the mlcube command """ + # TODO: refactor this function. 
Move things to MLCube if possible kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" cmd += f" --mlcube={self.cube_path} --task={task} --platform={config.platform}" @@ -328,6 +331,13 @@ def run( cmd = " ".join([cmd, cmd_arg]) container_loglevel = config.container_loglevel + if container_loglevel: + env_dict["MEDPERF_LOGLEVEL"] = container_loglevel.upper() + + env_args_string = "" + for key, val in env_dict.items(): + env_args_string += f"--env {key}={val} " + env_args_string = env_args_string.strip() # TODO: we should override run args instead of what we are doing below # we shouldn't allow arbitrary run args unless our client allows it @@ -337,16 +347,27 @@ def run( gpu_args = self.get_config("docker.gpu_args") or "" cpu_args = " ".join([cpu_args, "-u $(id -u):$(id -g)"]).strip() gpu_args = " ".join([gpu_args, "-u $(id -u):$(id -g)"]).strip() + if port is not None: + cpu_args += f" -p {port}:{port}" + gpu_args += f" -p {port}:{port}" cmd += f' -Pdocker.cpu_args="{cpu_args}"' cmd += f' -Pdocker.gpu_args="{gpu_args}"' - if container_loglevel: - cmd += f' -Pdocker.env_args="-e MEDPERF_LOGLEVEL={container_loglevel.upper()}"' + cmd += f' -Pdocker.env_args="-e "' + env_args = self.get_config("docker.env_args") or "" + env_args = " ".join([env_args, env_args_string]).strip() + cmd += f' -Pdocker.env_args="{env_args}"' + + + elif config.platform == "singularity": # use -e to discard host env vars, -C to isolate the container (see singularity run --help) run_args = self.get_config("singularity.run_args") or "" run_args = " ".join([run_args, "-eC"]).strip() + run_args += " " + env_args_string cmd += f' -Psingularity.run_args="{run_args}"' + # TODO: check if ports are already exposed. 
Think if this is OK + # TODO: check if --env works # set image name in case of running docker image with singularity # Assuming we only accept mlcube.yamls with either singularity or docker sections @@ -356,7 +377,6 @@ def run( cmd += ( f' -Psingularity.image="{self._converted_singularity_image_name}"' ) - # TODO: pass logging env for singularity also there else: raise InvalidArgumentError("Unsupported platform") diff --git a/examples/fl/project/Dockerfile-GPU b/examples/fl/project/Dockerfile-GPU index ecdf14622..de67e0373 100644 --- a/examples/fl/project/Dockerfile-GPU +++ b/examples/fl/project/Dockerfile-GPU @@ -3,6 +3,9 @@ FROM local/openfl:local ENV GANDLF_VERSION 60c9d28aa5e1b951e75ed5646ac20d5790fe4317 ENV FL_WORKSPACE /mlcube_project/fl_workspace ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" +# TODO: combine docker images (cpu and gpu) +# TODO: make necessary changes since now user is not root # install software requirements needed by GaNDLF RUN apt-get update && apt-get upgrade -y && apt-get install -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ From b04ee0ce68f5dbae4049283ecb9beae688adf187 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 29 Feb 2024 03:34:04 +0100 Subject: [PATCH 004/242] fixes related to cube.download refactoring --- cli/medperf/commands/aggregator/run.py | 1 + cli/medperf/commands/training/run.py | 1 + cli/medperf/commands/training/submit.py | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index 6a54e067e..0486cdc7f 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -80,6 +80,7 @@ def prepare_cube(self): def __get_cube(self, uid: int, name: str) -> Cube: self.ui.text = f"Retrieving {name} cube" cube = Cube.get(uid) + cube.download_run_files() self.ui.print(f"> {name} cube download complete") return cube diff --git a/cli/medperf/commands/training/run.py 
b/cli/medperf/commands/training/run.py index dbaed964e..9535c4acd 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -83,6 +83,7 @@ def prepare_cube(self): def __get_cube(self, uid: int, name: str) -> Cube: self.ui.text = f"Retrieving {name} cube" cube = Cube.get(uid) + cube.download_run_files() self.ui.print(f"> {name} cube download complete") return cube diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py index 89fc82043..5a662b450 100644 --- a/cli/medperf/commands/training/submit.py +++ b/cli/medperf/commands/training/submit.py @@ -41,7 +41,9 @@ def __init__(self, training_exp_info: dict): def get_mlcube(self): mlcube_id = self.training_exp.fl_mlcube - Cube.get(mlcube_id) + cube = Cube.get(mlcube_id) + # TODO: do we want to download run files? + cube.download_run_files() def submit(self): updated_body = self.training_exp.upload() From a147df83239cbd1e911e7dbd033dcce34270fc9e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 29 Feb 2024 14:09:05 +0100 Subject: [PATCH 005/242] remove TODOs from version control --- TODO | 47 ----------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 TODO diff --git a/TODO b/TODO deleted file mode 100644 index 3f810f427..000000000 --- a/TODO +++ /dev/null @@ -1,47 +0,0 @@ -# TODO: remove me from the repo - -FOR TUTORIAL - -- stream logs -- check benchmark execution mlcube training exp ID -- if association request failed for some reason, delete private key (or at least check if rerunning the request will simply overwrite the key) -- define output folders in medperf storage (logs for both, weights for agg) -- adding email to CN currently could be challenging. THINK - - ASSUMPTION: emails are not changed after signup - -- We now have demo data url and hash in training exp (dummy) that we don't use. 
- - what to say about this in miccai (I think no worries; it's hidden now) -- rethink/review about the following serializers and if necessary use atomic transactions - - association creation (dataset-training, agg-training) - - association approval (dataset-training, agg-training) - - training experiment creation (creating keypair); this could move to approval -- public/private keys uniqueness constraint while blank; check django docs on how -- fix bug about association list; /home/hasan/work/openfl_ws/medperf-private/server/utils/views.py -- pull latest medperf main -- test agg and training exp owner being same user - - basically, test the tutorial steps EXACTLY - -AFTER TUTORIAL - -- FOLLOWUP: collaborators doesn't use tensorboard logs. -- FOLLOWUP: show csr hash on approval is not necessary since now CSRs are transported securely -- test remote aggregator -- make network config better structured (URL to file? no, could be annoying.) -- move key generation after admin approval of training experiments. -- when the training experiment owner wants to "lock" the experiment - - ask for confirmation? it's an easy command and after execution there is no going back; a mess if unintended. -- secretstorage gcloud - -NOT SURE - -- consider if we want to enable restarts and epochs/"fresh restarts" for training exps (it's hard) -- mlcube for agg alone - -LATER / FUTURE INVESTIGATIONS - -- root key thing. 
-- limit network access (for now we can rely on the review of the experiment owner) -- compatibility tests -- rethink if keys are always needed (just for exps where they on't need a custom cert) -- server side verification of CSRs (check common names) - - later: the whole design might be changed From b9a91d9f9c57e8f6dc8a49b986b22bc2f22f7282 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 29 Feb 2024 14:28:07 +0100 Subject: [PATCH 006/242] update code according storage related changes --- cli/medperf/commands/aggregator/run.py | 4 +--- cli/medperf/commands/training/run.py | 5 ++--- cli/medperf/entities/aggregator.py | 7 +++---- cli/medperf/entities/training_exp.py | 9 ++++----- cli/medperf/utils.py | 8 ++++---- 5 files changed, 14 insertions(+), 19 deletions(-) diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index 0486cdc7f..a38e99155 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -4,7 +4,6 @@ from medperf.entities.training_exp import TrainingExp from medperf.entities.aggregator import Aggregator from medperf.entities.cube import Cube -from medperf.utils import storage_path class StartAggregator: @@ -61,12 +60,11 @@ def prepare_agg_cert(self): ) cert = association["certificate"] cert_folder = os.path.join( - config.training_exps_storage, + config.training_folder, str(self.training_exp.id), config.agg_cert_folder, str(self.aggregator.id), ) - cert_folder = storage_path(cert_folder) os.makedirs(cert_folder, exist_ok=True) cert_file = os.path.join(cert_folder, "cert.crt") with open(cert_file, "w") as f: diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py index 9535c4acd..d0648726c 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -5,7 +5,7 @@ from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube from medperf.entities.aggregator import Aggregator -from 
medperf.utils import storage_path, get_dataset_common_name +from medperf.utils import get_dataset_common_name class TrainingExecution: @@ -59,12 +59,11 @@ def prepare_data_cert(self): ) cert = association["certificate"] cert_folder = os.path.join( - config.training_exps_storage, + config.training_folder, str(self.training_exp.id), config.data_cert_folder, str(self.dataset.id), ) - cert_folder = storage_path(cert_folder) os.makedirs(cert_folder, exist_ok=True) cert_file = os.path.join(cert_folder, "cert.crt") with open(cert_file, "w") as f: diff --git a/cli/medperf/entities/aggregator.py b/cli/medperf/entities/aggregator.py index 0127cb410..c011b0865 100644 --- a/cli/medperf/entities/aggregator.py +++ b/cli/medperf/entities/aggregator.py @@ -4,7 +4,6 @@ import hashlib from typing import List, Optional, Union -from medperf.utils import storage_path from medperf.entities.interface import Entity, Uploadable from medperf.entities.schemas import MedperfSchema from medperf.exceptions import ( @@ -40,7 +39,7 @@ def __init__(self, *args, **kwargs): self.port = self.server_config["port"] self.generated_uid = self.__generate_uid() - path = storage_path(config.aggregator_storage) + path = config.aggregators_folder if self.id: path = os.path.join(path, str(self.id)) else: @@ -113,7 +112,7 @@ def __remote_prefilter(cls, filters: dict) -> callable: @classmethod def __local_all(cls) -> List["Aggregator"]: aggs = [] - aggregator_storage = storage_path(config.aggregator_storage) + aggregator_storage = config.aggregators_folder try: uids = next(os.walk(aggregator_storage))[1] except StopIteration: @@ -211,7 +210,7 @@ def upload(self): @classmethod def __get_local_dict(cls, aggregator_uid): aggregator_path = os.path.join( - storage_path(config.aggregator_storage), str(aggregator_uid) + config.aggregators_folder, str(aggregator_uid) ) regfile = os.path.join(aggregator_path, config.reg_file) if not os.path.exists(regfile): diff --git a/cli/medperf/entities/training_exp.py 
b/cli/medperf/entities/training_exp.py index 46764845b..6d4e3ee04 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -7,7 +7,7 @@ import medperf.config as config from medperf.entities.interface import Entity, Uploadable -from medperf.utils import get_dataset_common_name, storage_path +from medperf.utils import get_dataset_common_name from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data, read_user_account @@ -55,7 +55,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = self.name - path = storage_path(config.training_exps_storage) + path = config.training_folder if self.id: path = os.path.join(path, str(self.id)) else: @@ -129,7 +129,7 @@ def __remote_prefilter(cls, filters: dict) -> callable: @classmethod def __local_all(cls) -> List["TrainingExp"]: training_exps = [] - training_exps_storage = storage_path(config.training_exps_storage) + training_exps_storage = config.training_folder try: uids = next(os.walk(training_exps_storage))[1] except StopIteration: @@ -216,8 +216,7 @@ def __get_local_dict(cls, training_exp_uid) -> dict: dict: information of the training_exp """ logging.info(f"Retrieving training_exp {training_exp_uid} from local storage") - storage = storage_path(config.training_exps_storage) - training_exp_storage = os.path.join(storage, str(training_exp_uid)) + training_exp_storage = os.path.join(config.training_folder, str(training_exp_uid)) training_exp_file = os.path.join( training_exp_storage, config.training_exps_filename ) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index fc5076823..920160d94 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -515,6 +515,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.proc.wait() # Return False to propagate exceptions, if 
any return False + + def get_dataset_common_name(email, dataset_id, exp_id): return f"{email}_d{dataset_id}_e{exp_id}".lower() @@ -525,12 +527,11 @@ def generate_data_csr(email, data_uid, training_exp_id): # store private key target_folder = os.path.join( - config.training_exps_storage, + config.training_folder, str(training_exp_id), config.data_cert_folder, str(data_uid), ) - target_folder = storage_path(target_folder) os.makedirs(target_folder, exist_ok=True) target_path = os.path.join(target_folder, "key.key") write_key(private_key, target_path) @@ -546,12 +547,11 @@ def generate_agg_csr(training_exp_id, agg_address, agg_id): # store private key target_folder = os.path.join( - config.training_exps_storage, + config.training_folder, str(training_exp_id), config.agg_cert_folder, str(agg_id), ) - target_folder = storage_path(target_folder) os.makedirs(target_folder, exist_ok=True) target_path = os.path.join(target_folder, "key.key") write_key(private_key, target_path) From ebd46d4e962de33e326f49473eeb8b87431c1335 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Sat, 2 Mar 2024 01:56:55 +0100 Subject: [PATCH 007/242] minor server fixes --- server/training/serializers.py | 1 + server/training/views.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/server/training/serializers.py b/server/training/serializers.py index 42d96379c..4c7f5e9be 100644 --- a/server/training/serializers.py +++ b/server/training/serializers.py @@ -46,6 +46,7 @@ class Meta: exclude = ["private_key"] def update(self, instance, validated_data): + # TODO: seems buggy if ( instance.approval_status != "PENDING" and "approval_status" in validated_data diff --git a/server/training/views.py b/server/training/views.py index 72ebc6828..3cfa1cd40 100644 --- a/server/training/views.py +++ b/server/training/views.py @@ -50,6 +50,8 @@ class TrainingExperimentDetail(GenericAPIView): def get_permissions(self): if self.request.method == "PUT": self.permission_classes = [IsAdmin | IsExpOwner] + if 
"approval_status" in self.request.data: + self.permission_classes = [IsAdmin] return super(self.__class__, self).get_permissions() def get_object(self, pk): From 218315a28d7d2c23c6a50a5ed60f31fc9aee2ab4 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Sat, 2 Mar 2024 01:57:58 +0100 Subject: [PATCH 008/242] fix config storage --- cli/medperf/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cli/medperf/config.py b/cli/medperf/config.py index af2b81df9..1c7b0eac9 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -120,7 +120,7 @@ }, "aggregators_folder": { "base": default_base_storage, - "name": aggregators_folder + "name": aggregators_folder, }, } @@ -130,6 +130,8 @@ "logs_folder", "tmp_folder", "demo_datasets_folder", + "training_folder", + "aggregators_folder", ] server_folders = [ "benchmarks_folder", @@ -201,7 +203,7 @@ "platform", "gpus", "cleanup", - "container_loglevel" + "container_loglevel", ] configurable_parameters = inline_parameters + [ "server", From eb74c0476f81d664f622a57e04f86171efd3d536 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 5 Mar 2024 12:41:59 +0100 Subject: [PATCH 009/242] fix typo --- cli/medperf/entities/cube.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 3e60185e7..a95a9b653 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -353,13 +353,10 @@ def run( cmd += f' -Pdocker.cpu_args="{cpu_args}"' cmd += f' -Pdocker.gpu_args="{gpu_args}"' - cmd += f' -Pdocker.env_args="-e "' env_args = self.get_config("docker.env_args") or "" env_args = " ".join([env_args, env_args_string]).strip() cmd += f' -Pdocker.env_args="{env_args}"' - - elif config.platform == "singularity": # use -e to discard host env vars, -C to isolate the container (see singularity run --help) run_args = self.get_config("singularity.run_args") or "" From dfe40840b8c82284489895aaab0bfd3e32b3e73e Mon Sep 17 00:00:00 2001 From: 
hasan7n Date: Tue, 5 Mar 2024 17:05:22 +0100 Subject: [PATCH 010/242] tmp fix mlcube issue --- cli/medperf/entities/cube.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index bd495766a..4bb624d9e 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -354,10 +354,10 @@ def run( gpu_args += f" -p {port}:{port}" cmd += f' -Pdocker.cpu_args="{cpu_args}"' cmd += f' -Pdocker.gpu_args="{gpu_args}"' - - env_args = self.get_config("docker.env_args") or "" - env_args = " ".join([env_args, env_args_string]).strip() - cmd += f' -Pdocker.env_args="{env_args}"' + if env_args_string: # TODO: why MLCube UI is so brittle? + env_args = self.get_config("docker.env_args") or "" + env_args = " ".join([env_args, env_args_string]).strip() + cmd += f' -Pdocker.env_args="{env_args}"' elif config.platform == "singularity": # use -e to discard host env vars, -C to isolate the container (see singularity run --help) From 417424acf55e8251085f09bddff4cb84af775005 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 6 Mar 2024 04:12:10 +0100 Subject: [PATCH 011/242] add FL integration test --- .github/workflows/train-ci.yml | 41 +++ cli/cli_tests_training.sh | 439 ++++++++++++++++++++++++++++++ cli/tests_setup.sh | 16 +- mock_tokens/generate_tokens.py | 11 +- mock_tokens/tokens.json | 14 +- server/testing_medperf.sh | 69 ----- server/testing_miccai.sh | 153 ----------- server/testing_miccai_shortcut.sh | 7 - 8 files changed, 501 insertions(+), 249 deletions(-) create mode 100644 .github/workflows/train-ci.yml create mode 100644 cli/cli_tests_training.sh delete mode 100644 server/testing_medperf.sh delete mode 100644 server/testing_miccai.sh delete mode 100644 server/testing_miccai_shortcut.sh diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml new file mode 100644 index 000000000..473e7460b --- /dev/null +++ b/.github/workflows/train-ci.yml @@ -0,0 +1,41 @@ 
+name: FL Integration workflow + +on: pull_request + +jobs: + setup: + name: fl-integration-test + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install dependencies + working-directory: . + run: | + python -m pip install --upgrade pip + pip install -e cli/ + pip install -r cli/test-requirements.txt + pip install -r server/requirements.txt + pip install -r server/test-requirements.txt + + - name: Set server environment vars + working-directory: ./server + run: cp .env.local.local-auth .env + + - name: Run django server in background with generated certs + working-directory: ./server + run: sh setup-dev-server.sh & sleep 6 + + - name: Run server integration tests + working-directory: ./server + run: python seed.py --cert cert.crt + + - name: Run client integration tests + working-directory: . + run: sh cli/cli_tests_training.sh -f \ No newline at end of file diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh new file mode 100644 index 000000000..491174d4c --- /dev/null +++ b/cli/cli_tests_training.sh @@ -0,0 +1,439 @@ +# import setup +. 
"$(dirname $(realpath "$0"))/tests_setup.sh" + +########################################################## +################### Start Testing ######################## +########################################################## + + +########################################################## +echo "==========================================" +echo "Creating test profiles for each user" +echo "==========================================" +medperf profile activate local +checkFailed "local profile creation failed" + +medperf profile create -n testmodel +checkFailed "testmodel profile creation failed" +medperf profile create -n testagg +checkFailed "testagg profile creation failed" +medperf profile create -n testdata1 +checkFailed "testdata1 profile creation failed" +medperf profile create -n testdata2 +checkFailed "testdata2 profile creation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Retrieving mock datasets" +echo "=====================================" +echo "downloading files to $DIRECTORY" + +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/col1.tar.gz +tar -xf $DIRECTORY/col1.tar.gz -C $DIRECTORY +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/col2.tar.gz +tar -xf $DIRECTORY/col2.tar.gz -C $DIRECTORY +wget -P $DIRECTORY https://storage.googleapis.com/medperf-storage/testfl/data/test.tar.gz +tar -xf $DIRECTORY/test.tar.gz -C $DIRECTORY +rm $DIRECTORY/col1.tar.gz +rm $DIRECTORY/col2.tar.gz +rm $DIRECTORY/test.tar.gz + +########################################################## + +echo "\n" + +########################################################## +echo "==========================================" +echo "Login each user" +echo "==========================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation 
failed" + +medperf auth login -e $MODELOWNER +checkFailed "testmodel login failed" + +medperf profile activate testagg +checkFailed "testagg profile activation failed" + +medperf auth login -e $AGGOWNER +checkFailed "testagg login failed" + +medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" + +medperf auth login -e $DATAOWNER +checkFailed "testdata1 login failed" + +medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" + +medperf auth login -e $DATAOWNER2 +checkFailed "testdata2 login failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate modelowner profile" +echo "=====================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Submit cubes" +echo "=====================================" + +medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE +checkFailed "Train prep submission failed" +PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -p $TRAIN_PARAMS -a $TRAIN_WEIGHTS +checkFailed "traincube submission failed" +TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Submit Training Experiment" +echo "=====================================" +medperf training submit -n trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID +checkFailed "Training exp submission failed" 
+TRAINING_UID=$(medperf training ls | grep trainexp | tail -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +# Approve benchmark +ADMIN_TOKEN=$(jq -r --arg ADMIN $ADMIN '.[$ADMIN]' $MOCK_TOKENS_FILE) +checkFailed "Retrieving admin token failed" +curl -sk -X PUT $SERVER_URL$VERSION_PREFIX/training/$TRAINING_UID/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Bearer $ADMIN_TOKEN" +checkFailed "training exp approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running aggregator submission step" +echo "=====================================" +HOSTNAME=$(hostname -A | cut -d " " -f 1) +medperf aggregator submit -n aggreg -a $HOSTNAME -p 50273 +checkFailed "aggregator submission step failed" +AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running aggregator association step" +echo "=====================================" +medperf aggregator associate -a $AGG_UID -t $TRAINING_UID -y +checkFailed "aggregator association step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +medperf profile activate testdata1 +checkFailed 
"testdata1 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 preparation step" +echo "=====================================" +medperf dataset create -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" +checkFailed "Data1 preparation step failed" +DSET_1_GENUID=$(medperf dataset ls | grep col1 | tr -s ' ' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 submission step" +echo "=====================================" +medperf dataset submit -d $DSET_1_GENUID -y +checkFailed "Data1 submission step failed" +DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data1 association step" +echo "=====================================" +medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y +checkFailed "Data1 association step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner2 profile" +echo "=====================================" +medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 preparation step" +echo "=====================================" +medperf 
dataset create -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" +checkFailed "Data2 preparation step failed" +DSET_2_GENUID=$(medperf dataset ls | grep col2 | tr -s ' ' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 submission step" +echo "=====================================" +medperf dataset submit -d $DSET_2_GENUID -y +checkFailed "Data2 submission step failed" +DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Running data2 association step" +echo "=====================================" +medperf dataset associate -d $DSET_2_UID -t $TRAINING_UID -y +checkFailed "Data2 association step failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate modelowner profile" +echo "=====================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Approve aggregator association" +echo "=====================================" +medperf association approve -t $TRAINING_UID -a $AGG_UID +checkFailed "agg association approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Approve data1 association" +echo 
"=====================================" +medperf association approve -t $TRAINING_UID -d $DSET_1_UID +checkFailed "data1 association approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Approve data2 association" +echo "=====================================" +medperf association approve -t $TRAINING_UID -d $DSET_2_UID +checkFailed "data2 association approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Lock experiment" +echo "=====================================" +medperf training lock -t $TRAINING_UID +checkFailed "locking experiment failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting aggregator" +echo "=====================================" +RUNCOMMAND="medperf aggregator start -a $AGG_UID -t $TRAINING_UID" +nohup $RUNCOMMAND < /dev/null &>agg.log & + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +AGG_PID=$(ps -ef | grep $RUNCOMMAND | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +# Check if the command is still running. +if ! 
kill -0 "$AGG_PID" &> /dev/null; +then + checkFailed "agg doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting training with data1" +echo "=====================================" +RUNCOMMAND="medperf training run -d $DSET_1_UID -t $TRAINING_UID" +nohup $RUNCOMMAND < /dev/null &>col1.log & + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +DATA1_PID=$(ps -ef | grep $RUNCOMMAND | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +# Check if the command is still running. +if ! 
kill -0 "$DATA1_PID" &> /dev/null; +then + checkFailed "data1 training doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner2 profile" +echo "=====================================" +medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting training with data2" +echo "=====================================" +medperf training run -d $DSET_2_UID -t $TRAINING_UID +checkFailed "data2 training failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Waiting for other prcocesses to exit successfully" +echo "=====================================" +# NOTE: on systems with small process ID table or very short-lived processes, +# there is a probability that PIDs are reused and hence the +# code below may be inaccurate. Perhaps grep processes according to command +# string is the most efficient way to reduce that probability further. 
+# Followup NOTE: not sure, but the "wait" command may fail if it is waiting for +# a process that is not a child of the current shell +wait $DATA1_PID +checkFailed "data1 training didn't exit successfully" +wait $AGG_PID +checkFailed "aggregator didn't exit successfully" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Logout users" +echo "=====================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" + +medperf auth logout +checkFailed "logout failed" + +medperf profile activate testagg +checkFailed "testagg profile activation failed" + +medperf auth logout +checkFailed "logout failed" + +medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" + +medperf auth logout +checkFailed "logout failed" + +medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" + +medperf auth logout +checkFailed "logout failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Delete test profiles" +echo "=====================================" +medperf profile activate default +checkFailed "default profile activation failed" + +medperf profile delete testmodel +checkFailed "Profile deletion failed" + +medperf profile delete testagg +checkFailed "Profile deletion failed" + +medperf profile delete testdata1 +checkFailed "Profile deletion failed" + +medperf profile delete testdata2 +checkFailed "Profile deletion failed" +########################################################## + +if ${CLEANUP}; then + clean +fi diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index 581c27fc3..db6e23b58 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -40,8 +40,12 @@ clean(){ medperf profile delete 
testdata } checkFailed(){ - if [ "$?" -ne "0" ]; then - if [ "$?" -eq 124 ]; then + EXITSTATUS="$?" + if [ -n "$2" ]; then + EXITSTATUS="1" + fi + if [ $EXITSTATUS -ne "0" ]; then + if [ $EXITSTATUS -eq 124 ]; then echo "Process timed out" fi echo $1 @@ -72,6 +76,7 @@ DEMO_URL="${ASSETS_URL}/assets/datasets/demo_dset1.tar.gz" # prep cubes PREP_MLCUBE="$ASSETS_URL/prep-sep/mlcube/mlcube.yaml" PREP_PARAMS="$ASSETS_URL/prep-sep/mlcube/workspace/parameters.yaml" +PREP_TRAINING_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube_prep.yaml" # model cubes FAILING_MODEL_MLCUBE="$ASSETS_URL/model-bug/mlcube/mlcube.yaml" # doesn't fail with association @@ -93,8 +98,15 @@ MODEL_LOG_DEBUG_PARAMS="$ASSETS_URL/model-debug-logging/mlcube/workspace/paramet METRIC_MLCUBE="$ASSETS_URL/metrics/mlcube/mlcube.yaml" METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" +# FL cubes +TRAIN_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=2" +TRAIN_PARAMS="https://storage.googleapis.com/medperf-storage/testfl/parameters-miccai.yaml" +TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" + # test users credentials MODELOWNER="testmo@example.com" DATAOWNER="testdo@example.com" BENCHMARKOWNER="testbo@example.com" ADMIN="testadmin@example.com" +DATAOWNER2="testdo2@example.com" +AGGOWNER="testao@example.com" \ No newline at end of file diff --git a/mock_tokens/generate_tokens.py b/mock_tokens/generate_tokens.py index 7d964982a..c4b6420b3 100644 --- a/mock_tokens/generate_tokens.py +++ b/mock_tokens/generate_tokens.py @@ -23,16 +23,7 @@ def token_payload(user): } -users = [ - "testadmin", - "testbo", - "testmo", - "testdo", - "aggowner", - "traincol1", - "traincol2", - "testcol", -] +users = ["testadmin", "testbo", "testmo", "testdo", "testdo2", "testao"] tokens = {} # Use headers when verifying tokens using json web keys diff --git a/mock_tokens/tokens.json b/mock_tokens/tokens.json index 
096b90a06..f4063d194 100644 --- a/mock_tokens/tokens.json +++ b/mock_tokens/tokens.json @@ -1,10 +1,8 @@ { - "testadmin@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.mp0Q4h98q-NcFV6ub4qywiOy5bFAE3pSpH70UZ86MdL0N_U_xpkWZ3VmrF3Asf5Z3LtvsQMcT5S4C9gAdxFhvKXfSXYim2dQEpLMLpcT02aOf3IE1bTzm7G2WXiRvAXPQRlqt1KFfwWMuT1P4bZvB53Si3_slBX6BVPBpll5ZaYavQ8k1gAg-28zcMarDKJBnDbDy7jCMSHeKUQ-LrCV7LlCN1sdbXjhHhx2B13FdaNXHxPZHSPl_grtMZlZNiq6xzRhnVy3nUBCSuGxfzjDHmncF15x23eX0NDlZu1HS1uOsRT_3sonrcRE4r4XKFCjvyEAORynHez4v5BMXXqyf4MRYxYQ7py5OBfUjwABd52N21iIvikQTRF5yRWfd6QxkFMQ4oQ1IQUd8FHnzYX0gbHyxG03HhGdR12Z_EM56pzixgcph4bTIkr2_w0qGRfAvl6gmyFl5KEUhjeO7vZC24dJTej0yr729dcB037VK8hhCYBIKDl0OvVvSFEPXNvxlkhM86ae4N3UEekXhb4ESF-jx3Akxbx4GS3gTENmoHST9PCBQvPZdJqwXcH8lUbHlW4ugQc0qpqYRNy4z4_aRR2Q3OWkpJzkAhmcdYYiQU9ACkbRmggSDIYq_P4HhQFG15QfTcNLJwBwia129LZWzwWa-en9WcIIceQ4g-hmBq4", - "testbo@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.cF0KGfhTNv2CHYJw9QCRnakMwqxli9piLoeliUh9KBvHUTrLFOz2S6ZtvbiaTlhi8sx_TPmDDzNl09Vw8TN1mzXUorvhNj5Al0_nAGjk10XnTaRSwj2ZkA3_vDggXI6zu20_1DI3JR9gcOYq7He_uPqmzRfnFnCw6MSpM_FlvxLkVJzaqIe8Jh1XslXIa3Jr7uWBv_Gigw6DMen4Y7jsXuFPRiQoF86zEziv-2l7VjTx7EovS3loGQjDCJ1LwH053QdlmgmMa_3P6viKpCuvO2jablbDCMAxtg6JJmgbPYXQ8kmrVdRBEtybUY-4o46OFjs9vlkiK2JjUffvoYWQViwZKcgMQ6SAWQth_3HvLVlCghP6Qg-0eaJYR91gkbJ5VMXRccugzPGzDQ0o9yOSenIhj2GCKamn0YHMcFV8t7TtNdy9BuMjLiTSDz_v3nVU_rAdgqvwRXbkTuixdUYqx9h761wqpw_VaQnaVkaCJKhPOC2V8k8gygZsfNZrsZ4j2u5Dd2JmZgLzLCMp1XlXQxAboaAGRDgb6qZrojRaYYMQ59MAO6pdJ3IShEZMcEMKZdsbNVhZcghBHOFAfEuHgo2DC4itTB-VZVZc2W-j1gP6p3PqG3k5mHJ6_PjRDenqTmatDJJYDjxgHYC2nSQkOPD1CSu0XPRMvgReajYgRwM", - "testmo@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.h4ElJzfmsK_LWOYYhMPyLvg0VGtG7IM1JKA1ajZGSPnsKc_1uEtJ87k8Tgg5mIkBA5Fggv0Qyrmc95awAbpaqxCaeCJQAawRx98UCTQ6emmW4w9GE-eryLMVv5EpYyugKypUoQ6bpSg-kTEQ4K_f4i_pHBGAwHPjXR1brNy0krbrL56Q6P2SEJaoynPXMT-qVCIuND8DXb5XBNVUtgXP-VINuG7TOdJyZ50Id_8JcX4-8CO2qSHs6jXy6sgHX6FcCm38E-1xEhNQWT7OX8IX3Tlf3idSWKoV_ZQQX1z5GXGFbOgLxnr8aY2yuhmnTHNIWeMoIHIIw9bODagKMBAC-O-uvyr_ejcMRNahcoukQD2aHFjykUNKjkCSKGMi8dAWhVr3rl0YfFR5CNzo_vA3s2TICHYZvD3_jik07OXnIK04otAHH6mfBypHuA4ThRAGU_twW-FLjEDYfh5Rxm9DuIcsPM7fGr28pbZSIw8nnPYf41fVzBIzAhvcD3mpjqkU6C4vXq-LEd6mBNIGMOsYUm_83Ae1gpBgexu4m4O5DUroILAjclEIG-meF5PHmR3uGOt-oiJ39jXUo73HdU5bGMPzSzC5TNxKZMIdY3yCReNUCKvjr4R_YJbfTKn7aJS2prs_1kQ_0akcICiPmrXAng5nQV32MiqCiq2xu9kxk0Y", - "testdo@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.heenig4R9kb4aIPuQkbmjVUwHUQPL7yMwZi9EZ3YMkpVWKwDb130-BX0ovYpxzkfDZoIlh_fUYIzfEPrqyUt2HvVn93JgJJp8le9JndNREjoq-t01RGyh5ieJBmmTxXPJ7X17HdCcAGnB0DNBFUNk72nW0qCmgCHAypey1QPi8iiINTuKYRHCy9yqC0hyQWkrUzue8LxGxPCD3AdKn7P19yQMTPDF4TRC-gzDSm6AurbIEJjnyGe2YIhI_w-m6qnHg5qlpw6CkaEa-yosD6tmCuxSRZ7-nQna9S2FdQRHl9344r5dzaN4tve8k3ZKsdtz_tfNuwAZAmTUMpExsNkmkDyTlQ6XZ_9u1PD0nEaeusZR9PBPEwEB3JQvC_lOMg2rB-gPlB1PPxNCS8M9WivqFelwsK_ddYhymQrcB6ZqHQe2ITiQKX7ONtqScj3JoNuGvt9nmO_JMC3GJlJSPuEm7ZbuIProTrE7aZpEPztYiyOHRQzGA-qsLM17so-QcRB6K2Z-XL3AndWsAfd1FaRxTtTVR061yTb1rxW35ChoHz4T-ImSXk04wY2SuW6oe1VJdlE494UM1WdPUURjXtM-bOGtZbnolSTLE-3DyUcrYB9I5GCUEDMkVOERXX7z5nsU8TUohUxwe57lreCVWABwLZZBOmf1gSEPejX3HuKMfA", - "aggowner@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoiYWdnb3duZXJAZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6ImFnZ293bmVyIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNzA5MTYxMzM2LCJleHAiOjExNzA5MTYxMzM2fQ.K_OcpQ5gxnJlg-G_-qt_YFITTnR2NqBwcSsRxK-XmF7x8nf3eiZAJ01pDeGRutLMyUfMHbDaGdIds0L30z86Bb2QAOnrh3Zfq25xQpyxg3-hXiRGeMURC5CfjTuibAsWyi1514uC9JpRgWrnrzkI8m_AaiElRW2rSgp_QZ08tDvc9ktSIJzGXACMiWsamWpTSvf4mu0Snp_Xe1K0yPZrUBczeNaOk8Qus-OhfYXHuKd4kiuIynUgaH9SiylQj8-wzb8zi6iW3x1qsfY_qexo1EHc5b-58-NV-0VY0PCmWIIfPQlxustn2h4xr8j80xP-1LnSxvG_0IOXeuiW88IVXy1KzutwiMuRN3qGusoLspRVWqQh5aVLGsRN0VUzJAUpPmUSCjp6VacFh3oRWK2vdM4R8ZcGSwIrlup_5mCq7_hGxQXTu1gZ1e9hW3LYYnuCTFrPPCzqqGSeXkHrb37oN85yW-JdL2vBoGnhNSxZIMAgMQK9Cq4BR1ZdDZpKu4MXkzC_L4ZW3vVlOpgmKJ9oYDrJtTwaPuVZ3ratIlZPNt7Dp-4z0MTATUcvyjxYSxq1qFH-SnA1pnmA6OVEIKgoCn9rjEfg2v_y6ad2fONfg4czVGJ46HFcYmJIZ84Os9ME6yAlciqNR89Qix9NNRvCbfd39Y5K40K_25l01S3JcgA", - "traincol1@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidHJhaW5jb2wxQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDEiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.p6T5fX5xLGZnxyTW1258CpjCFJ_IA3PyR-h7aR0uxQPM1SeVDECfEWU85YGsnKlxvZJItsmNqAgvQE-cCUnC5DOq5Fo6XNEofbGi_LCDOgPxZ0ywNkrQk_H1S3VZqt4ukU2smfh_0biNI3qvWkEQL__VnStRPm9UGFP9ZoKPqbKxL81oM16FaWD7Ahy9ZofgD4v7fAvp8Yvmmw4P6LjIqB0GG7H2t1Tg8qwgJV_i0zZ5yddVUcn08EZr1H4xlkk7FhpFLCyMSQ9NKJvT_DkHXwuzmsxn3DLOr5lfQpJLM7DfJekmFHJwqCtbC1YsDQ5RZCcbjMdpXQyE2WWs6zraGelL1DcRi9chuZSCQD4y2yfhm8eCW7rrrmvUYVP0kIwt0Bj9TvjzBjLmFpMgY23RPV-iI6bcqcU7p7k3-IK6WO291O4iDM_udY5lHqm2wus-57pVBhSRQ4TOvrcwD6K9uX64j9zbGDplhOkhYGp2M9WsWckWazYQt5tvSVj_63cTAjDewQvtv2UldriGCDTIQzGJaHnOqwSK3iXDV2KwGT_m-LEdvGTCpusi9c5J5mFVk8QHoFZbLNKW9Yo6KfxVZbbPf0p2FPjnLaSjcWFPkQOLrQLC8N8NhBR4i5lRSVmc7i8uitwyOYJT52wEUrnEZlgx3uh9Bv0ZZaVDCzK6HDk", - "traincol2@example.com": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidHJhaW5jb2wyQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0cmFpbmNvbDIiLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDkxNjEzMzYsImV4cCI6MTE3MDkxNjEzMzZ9.q6atpv6-DYSt0-F3a2Pb5OWKF6GjalX1eAASmNv-hn8vQtPGCjYcJ2EFC9FrR2Oc8loiFSPmu4POkLMxdXzsX436_chd8K28xssdrqodexvDrHr5W-3NK4LhDsNZoeA3D9Sg3zIP80q94YNXmCKNmZT4wuRmcH5pdAQoj5DantfKe-Ctb0tZb1p4Qt4XErL6BkY1X7Qw9UsKVTM3T9Aeb3unTV_m6k-jzMHpxqA4VrUS9KmfZS_0bpv4Po-X16umvNTiAxUchXTPHrn2uNhFz-4xxwFoGxGBxnWJXpBNe4JaJK-_yKMRI6qBa8A-t9wXlj7Fa-ygDYglenFkTItmIUZG27WJEglU6Ue4-AjduBwAMAvXNzEZUhKsdEns4yPV9DiGxlJfKqBNHQacx89SweRGqdm5qxPThFxtZw3zNbX1d9-THYVJnAbm8VuvklAI-4PxPghJP-ge2JUskutpuRDuGh-mUSKARc4c1OZdwxurycYkYgJuJ_a5NJbcq6s34TuF4acRh_hoNtNqA_CmxvvG8QLYLtTpO1nLFyMNxNEHVVTeKi5PkdNq2IEGwTR5M1nW-OSF4nwS3Hh9S_X7PvdTaEZTPFcjjvnBwghykLsi_j07xl2J2Ys9QYs_LANtV6dldnPBKKLg5FgF6rgly-oG5p0KWpq1lU4w7Jsq6Fc", - "testcol@example.com": 
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGNvbEBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGNvbCIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcwOTE2MTMzNiwiZXhwIjoxMTcwOTE2MTMzNn0.RiHDW8kT60yzFgIyoGt3QS6ok6VAzWLEYFe_Tr_qZV-IkWJULic9fAYkajx8fg4a1jLFRaeClGs6M3p0pj4izixZnCZ7v6OZFOuk-qeymtgSgRoo5gU7fzcSCzW34AHIC_5x2FlXY7Kp34r3pBzsWwF2KiNufeywdDP8bxWgz7_hwOf5Ne6BnKj1UtlIWd_u6wIHQ-OHX0bWQM_fqQCBqaP-qxkOCsmwNHqXf3JLs-fsi9BzPBORlBIl6jluU9gCsLBlRZ9CUv5zrTiLZUREVFmKy8C-eQE6TEPKbUxnWvOggYSnXPf8LsI4hLXxki8RdSUdDcG7Qe4bPUs1EAA1TNhHwlphGXnbPWSK6VAFFfsYDHwNZUEH3yYLXvRQcJGmcJOaIvuXOJiQ6kUYkijlNVrqLVFd7xotjCoavftdrj2uizBtcWEZTD1Sd1gM0o7Ok-B2BYzJXfPx6lNZu6k1oXcdTulkrkmFEKPvYG1WjxAu7FxLxoWLwQBUFeG2gtyAWUeh9gJ-KMvm0BOgXalSz0GQY7F5IDGhucKHxGA3k1RBCKEHz2KTZ4x1vle5rYR6jZPPb8iTikz6KaRTBF7WZTGd3XgDXNx-K2hdH3A3CjBYQO8aEQgPnwpKEn9I5VoVhT_QCri_OP4S3N82OpcuciAOoXu5y6BAS816Un2gXho" + "testadmin@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.ay8fatKc8rjT-Mu08QMz00D8BXmRc74M-02KZqdL8dR71CX6rD2DROSQ9wvf2sgHANcFoNWkYyr8S-Su4DqOPV87L2Jczs2tIPLVSEW28mYrR8YPysNHsSUh3eKi-7wX8F_gxpOhRdjo3Mqa_t3tw5ANfFrRVRl6SF8Mq9mOzirO9dcT3ya4WEGumBrszpJXBWJJxiNr9et1QVBKASUJVY2eclDUiK5vnokIS1nHrPL0sVos1Glcj9gtHyITmm2op7snoMuS65sjAD4dRl08XtB_amOoSZfzYmL8zqQXDYcxgX7zlJsuEsQ28Wm7XjG8tULwLK1XStexFSgL_Kp8HNUuDmWaJ2u-rCpW01Xg86VjQRuVm_eOTRDu8P6h9r0x2f5JykYrdqYw6pDUcryc8MYpceldFx1XZVv0-Fm_5LtKYCh7P9hlN-ND7soR-qeNZ8HWCOVIOYb4257SZ1rhO6Z5qDFJEvrr5aNYzNXLD82mIqTJPN6nLbIVo_EnkHIhPK9QIkfW-tGxgvxfn6zPZNBmWa4EdIRgXx-NTRbCM5SzNmqF89_jZjdIQBdB5OFSdpToFJQ8VvdCaMSPZodTHbv4JZ2sHt0Q9byYMDjVKL6lE_4nyWkVc4X72CG_2LnbkUozqPgoG0_pyeaBvK0oZKAo6YWkifCDZbzUtRh6RKk", + "testbo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.oxbLFiBx1SUHL3QY1v7OcwZVt_WPa_4yNhNCQP4wb78d8QhIrTspkC4DJQjWSQIFtnozbcCwmuNbc-EE5E_VaABiOOUl4ayT6TvFf73oeTvNTLB7JDD8fGFGIAmvo5vwTgtNEbIpE6aBXVAsc93kcRhLEIAtpyj773sruSiLGew-nn2SvXbhW6W1_dh4u0uWBFfhYmbHix9TjxfqDR8437PA5pxr4kPObGvU7DoMm3LvUibLOOc5wu5KeVSaqSNXoZVO-a56ffEA8qdh89HjoDC7MbChDhS3kRzMTh_kVrTv1u29yi1VX7ayAfyGZt17s-R8NPEI5pkxs_gAtu8rc77wZOf8WBXVDlgqA-WZfLT_9jwgWgzl139x73Zdv5c3ptWzFdL2PdvBP4mnzAjA-53mKUqXFqeKbJUvC-P6wGC2k50A3__LOKvphHgx1p2inPaEaD1mqKStHFeb_v6PkDBNp6_654IX68Grnwm01pR67gzWnv3Y7mCZnfHXw-WA662rkPKySf1-ZOYkxw73WRlTQjYn7JL7MjpiiUKTe_zfAO7HvSE2qck2FeXn_iy4CqS49JV5Ur-bUTIr7j0rtftNpFQjmFJFBYpPtrO9r8CO58beYluKliwJkpw6YwphXEC7qvIWx89Xd3PS_A1IVBAaY2cm1cc5PRZrI1hvyh4", + "testmo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.KEs-Zwhq9sGt92fRVIZ6k2wnZ1ij6vZ6zkUP9_ct0zHjm4LszodN5frGxF1NqI10Z-9GCfIButY1cDur78uY005gaEv63iGzETFYOsXaG7y63wzK03Dapwri8E9uPNdTXkiNCyUgd3PzRrcQsKgIBUJVCR2yL_KQacctU5Y9eXhGujQu6PzxFlBti5ajK2t-5sjOdpQOsOOQfRxVfavo1LXAsoEgZqsPnIYtWIC65wfj9gvyPqPjyMxK4jzEF2iszyCq8dq2ubs-7DsTOWTq4PpT9nvmu1h5Sl__q4edIFj8fpzNrY_r2-iKipRmfL6hxiZLGSY419Gn4iXnrD_kBjFQ7iOC7H-v_M_r2ORBm7WxObaKc_Y65Toy2swh_aKorcJzqgCqufkFU9JqTzolhMZStwuRWzT5sXE7434eNzOm7ogb_ogBB2zLyvgLpGIYUEkmQE62tYoT2USEZKI_eTbyxDOxY-WOc-aE2CjBX8K3v2gH2Se1TxVJDa67mAVH7D0XltJIVsYK14Vt989C3K9ZSThWRJciwhXGlHsOVWTL2Wkr9OQxIYEKrREBDF6cxMdz6JrEvoan8DdqYzRDSRhEmFuFUYXZNuXGa2Yz8uvjGzN4rgupOKG2-PMZUH9DgJvqU_6rY_6VGJzyqVGu9HmAlMKvRTr27kw-_UAFewg", + "testdo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.ZBvEnVHNJtVZBBoRsuMJng7TY1MMJ67yYMpwprJBvTvjAZmltnC5_WlXdWnBcoWvnxj1rHcB1LD6pY59TWzYxtchGGmkcWqo4nVjUVSi075vfu3xgPkasXaR3PK9xdTH0aHEcyU4FsNOpnjo1YBqnYS5jMg9Rvz6MGuww1m0coED1sS7YFcKaDj-qtWYxOz85PUyxLLu9REvTdgjvFfr4HdC9bNx5YEcDgpbt7_ZKDiPhzSOFN8ed6R-wyR3gQCW9sEIFjFeO9t-gbPhsZymfELSAHIQ0JWP0hfQavCAN9XQqayrT3wQMCgiM9hTXylnLA_A2_IrpfeInASuDY-CsS0YbkF3iAVes_mxFmokC_dOWPJs2b7P-bXo6gbPpLGkUhzCYREBkfMU1wIqur5IC5PGpKXF4B3I-BSt9knD4a2DXFSkxyxOJaJh4cxbnp3GmdLWP05f9Cuu6MDv79byX7Xq03XvcRTxbTYiiFgMWxGQV3YmdClZEMmPl7t870iZn6XSsVKjvcxihJEp8i4FCyI6v3LUfeg1uIZEIBK04TntdmyW7u_3uE3wdVbTeSvoaW2GGCD5kNyhfVtJtIdyYgtIUQ7JOjHMvKdmj_keVPuBjrfD1QM4MdeYJGtJ_QcZ9LcpxWrbg3FIsA79WC3_qw4pUUdaL0ydSmB6V0mrg5E", + "testdo2@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvMkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGRvMiIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcwOTY5NDEzMCwiZXhwIjoxMTcwOTY5NDEzMH0.SzzdojAzlU7wTucVgY4XJlffVYfdlQhuPJH0ySZH38p6wwnxYNpgHlj3RkE_5p1IIwmIRqpyBxhDO_W6PDOhE4JofwKadoEPoQ5N684wwWKnQr0NI-71gmyGj2sng1BDX0Lpi70yv6iP__OqLAR8tlcIEv7flCy3qpppIxhqBoybE3XBRMCBwrgyO3aurdAW2lZZOihorB9zUjaXlULyvLRxftQ30xosL8JeYfAWWFleHuxJfOK5X_F2vMcsU89jsDf93YMWtyDQBGhFHVXHTA8VLazc1ve5DCkXpCZU0qSo3Fg-8bhrOhZbManxdhLxU-qwvND_uAjch4OC_uPHLUBHMhkWuaa6Ift0EEvgvBS_-0LxplkMiP7pk2YCqpnjB01_1SMHgz2ubAAaTdmd3oj9JcZzRSds9-kFhTcdHA6B4Cx0ZxZPFOhdt2IPCk0D5MRN35-1ZaLwwCEi1NK2XcG7P06_HGZKUV_f6B9enkCevyIr_XDnkGvPet-9kh7y-61ee9qZ1xtruaofw9D2iNaP8V07eCmydjn4zIpk3QehdRofCDF62f-yYYxy6h1GiNH6No0ROsCLBOwLd0--TAxOGtpupLYE_Pet5CleUuQvcPDYuvMR5Hq91UPw8p-_ejW6FbvIPmQio7mEJz_YMeEdHkCJZxNLk2oLQBz8WwM", + "testao@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.FfH2lvnqV7FGfUTzQNS-Y9zPA2SbO6onry_fOQtwY08nJJLvIYifMJfOe1tqUVjlud2bG_pFhbI9BYuhvRdt8Y1dpReIjcFkB3gLB1DYkzSwcXVfjdphVOZlWv0ZQTPAoh8Epu3zYtoJNbQukOGPfclMrzndNOWs4k1noZ-xtAu3j-iS3VJDJnIweYZvF_eHJum-xl3-js0mxbLssr1FSx2JZQUuYs6U-SO_gyVSmCpaNb7klMBgfYPZPO2GzN9Lxtv4INEvtg4J4nC5f_SPz8xs9efrWlmdrTsxD0h916Pv3u6hrBawcGzzS9javDlap8HOKgxMtdx9-auwsYZ1-UlcvBBqLJjGIbgAL2ncREpUHIOQIt2dWJyRQz5Xkl6uMjeW-BfAfIRM-oXGd-CObY7TuOlsUcA9VYQ9jvxp2f1bNGt0-Ib0PtnnzNhdEL6zuw1oUaEi-ST5xG_yHKvAa_xfZOncGINNPtvzh77RLVY4eWCcVnwV2OLGNYd0fxLyrm1TfGpUUY4Br_3_x9npeurY8twrkbqUuUvsXbB4TkKgSF8OnyCW-Khrg6t09UURrBYiHUa1jC2RJqMaBv-sUXzVlIU3EG30wheUNnzmctiAmh02XKEpmEE83go6XQc5h8n0hvyLbXSmaftUUBblkFwSFLzPRrg1qIti_lM_B9U" } \ No newline at end of file diff --git a/server/testing_medperf.sh b/server/testing_medperf.sh deleted file mode 100644 index 861a3dea4..000000000 --- a/server/testing_medperf.sh +++ /dev/null @@ -1,69 +0,0 @@ -# TODO: remove me from the repo - -# setup dev server or reset db -bash reset_db.sh -sudo rm -rf keys -sudo rm -rf /home/hasan/.medperf -# seed -python seed.py - - -medperf profile ls -medperf profile activate local -medperf profile ls - -# move folder as a created dataset -cp -r /home/hasan/work/openfl_ws/9d56e799a9e63a6c3ced056ebd67eb6381483381 /home/hasan/.medperf/localhost_8000/data/ - -# login -medperf auth login -e testbo@example.com - -# register mlcube -medperf mlcube submit -n testfl \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=1 \ - -p https://storage.googleapis.com/medperf-storage/testfl/parameters-cpu.yaml \ - -a https://storage.googleapis.com/medperf-storage/testfl/additional.tar.gz - - -# register training exp -medperf training submit -n testtrain -d testtrain -p 1 -m 5 - -# mark 
as approved -curl -sk -X PUT https://localhost:8000/api/v0/training/1/ -d '{"approval_status": "APPROVED"}' -H 'Content-Type: application/json' -H "Authorization: Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21fY2xhaW1zL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE2OTA1NTIxMjQsImV4cCI6MTE2OTA1NTIxMjR9.PbAxtzBxPfipnuYGPx90P2_K-2V7jPSdPEhzHEW6u4KnUQU8Gul6xrwLsGlgdD19A6EzUtgfQxW2Lk2OITcOD0nbXcjUgPyduLozMXDdTwom19429g7Q5eWOppWdMImirX3OygWaqx587Q_OL73HZuCjFcEWwyGnhB62oruVRcM6uDWz4xVmGcAwdtMzCBYvQj9_C-Hnt9IYPgnKesXPr_AP98-bdQx2EBahXtQW1HaARgabZp3SLaCDY9I6h91B7NQ-PDWpuDxd0UamHSaq9dNPbd0SsR6ajl80wOKQaZF3be_TKJW0e0l7L4tnsbbSW23fR1utSH2PlNFPBx3uGGe2Aqirdq16fAWqvDNO8-kiVRpeikp0ze17lTYqtw2-GZIxXyc8rG-NPxz7R5lMg7ARu99e5nLGFHpV5sMNUoXKx5zoPO7Y7cO5mdzm0C_2DARB7imagKsL5eLc5fcYDEZBl0FtkDgT_CY3FEuH_X3DgPwEP6wE2IFGnU1zEXtuNd1XSUxvxxZ0_afoX54qNuz3m9qzAKuYJkkziiApdIPE_bXX2ox3-Z_Q5RfqvtLRJoE64FaOMr_6xCq_77hpPDpWACQaXCwn736-Jl8nP1HcGvdDa980dzKaih4mQ-FtFZ8xhMXU7jA_Bur9e2tg51TxBzAyd4t4NNk-gYaSUPU" - -# register aggregator -medperf auth login -e testmo@example.com -medperf aggregator submit -n testagg -a hasan-HP-ZBook-15-G3 -p 50273 - -# associate aggregator -medperf aggregator associate -a 1 -t 1 - - -# register dataset -medperf auth login -e testdo@example.com -medperf dataset submit -d 9d56e799a9e63a6c3ced056ebd67eb6381483381 - -# associate dataset -medperf training associate_dataset -t 1 -d 1 - -# approve associations -medperf auth login -e testbo@example.com -medperf training approve_association -t 1 -a 1 -medperf training approve_association -t 1 -d 1 - - -# test nonimportant stuff -# medperf training ls -# medperf aggregator ls -# medperf training view 1 -# medperf aggregator view 1 -# medperf training list_associations - -# lock experiment -medperf training lock -t 1 - -# # start aggregator -gnome-terminal -- bash -c "medperf aggregator start -a 1 -t 1; 
bash" - -# # start collaborator -medperf training run -d 1 -t 1 diff --git a/server/testing_miccai.sh b/server/testing_miccai.sh deleted file mode 100644 index 70554af19..000000000 --- a/server/testing_miccai.sh +++ /dev/null @@ -1,153 +0,0 @@ -# TODO: remove me from the repo - -# First, run the local server -# cd ~/medperf/server -# sh setup-dev-server.sh -# go to another terminal - -cd .. -# # TODO: reset -# bash reset_db.sh -# sudo rm -rf keys -# sudo rm -rf ~/.medperf - -# TODO: seed -# python seed.py --demo benchmark - -# TODO: download data -# wget https://storage.googleapis.com/medperf-storage/testfl/data/col1.tar.gz -# tar -xf col1.tar.gz -# wget https://storage.googleapis.com/medperf-storage/testfl/data/col2.tar.gz -# tar -xf col2.tar.gz -# wget https://storage.googleapis.com/medperf-storage/testfl/data/test.tar.gz -# tar -xf test.tar.gz -# rm col1.tar.gz -# rm col2.tar.gz -# rm test.tar.gz - -# TODO: activate local profile -# medperf profile activate local - -# login -medperf auth login -e modelowner@example.com - -# register prep mlcube -medperf mlcube submit -n prep \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_prep.yaml - -# register training mlcube -medperf mlcube submit -n testfl \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=2 \ - -p https://storage.googleapis.com/medperf-storage/testfl/parameters-miccai.yaml \ - -a https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz - -# register training exp -medperf training submit -n testtrain -d testtrain -p 1 -m 2 - -# mark as approved -bash admin_training_approval.sh - -# register aggregator -medperf aggregator submit -n testagg -a $(hostname --fqdn) -p 50273 - -# associate aggregator -medperf aggregator associate -a 1 -t 1 -y - - -# register dataset -medperf auth login -e traincol1@example.com -medperf dataset create -p 1 -d datasets/col1 -l datasets/col1 --name col1 --description col1data --location col1location -medperf 
dataset submit -d $(medperf dataset ls | grep col1 | tr -s " " | cut -d " " -f 1) -y - -# associate dataset -medperf training associate_dataset -t 1 -d 1 -y - -# shortcut -bash shortcut.sh - -# approve associations -medperf auth login -e modelowner@example.com -medperf training approve_association -t 1 -d 1 -medperf training approve_association -t 1 -d 2 - -# lock experiment -medperf training lock -t 1 - -# # start aggregator -gnome-terminal -- bash -c "medperf aggregator start -a 1 -t 1; bash" - -sleep 5 - -# # start collaborator 1 -medperf auth login -e traincol1@example.com -gnome-terminal -- bash -c "medperf training run -d 1 -t 1; bash" - -sleep 5 - -# # start collaborator 2 -medperf auth login -e traincol2@example.com -medperf training run -d 2 -t 1 - - -############### eval starts - - -# submit reference model -medperf auth login -e benchmarkowner@example.com -medperf mlcube submit -n refmodel \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_other.yaml - -# submit metrics mlcube -medperf mlcube submit -n metrics \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_metrics.yaml \ - -p https://storage.googleapis.com/medperf-storage/testfl/parameters_metrics.yaml - -# submit benchmark metadata -medperf benchmark submit --name pathmnistbmk --description pathmnistbmk \ - --demo-url https://storage.googleapis.com/medperf-storage/testfl/data/sample.tar.gz \ - -p 1 -m 3 -e 4 - -# mark as approved -bash admin_benchmark_approval.sh - -# submit trained model -medperf auth login -e modelowner@example.com -medperf mlcube submit -n trained \ - -m https://storage.googleapis.com/medperf-storage/testfl/mlcube_trained.yaml - -# participatemedperf benchmark submit -medperf mlcube associate -b 1 -m 5 -y - -# submit inference dataset -medperf auth login -e testcol@example.com -medperf dataset create -p 1 -d datasets/test -l datasets/test --name testdata --description testdata --location testdata -medperf dataset submit -d $(medperf dataset ls | 
grep test | tr -s " " | cut -d " " -f 1) -y - -# associate dataset -medperf dataset associate -b 1 -d 3 -y - -# approve associations -medperf auth login -e benchmarkowner@example.com -medperf association approve -b 1 -m 5 -medperf association approve -b 1 -d 3 - -# run inference -medperf auth login -e testcol@example.com -medperf benchmark run -b 1 -d 3 - -# submit result -medperf result submit -r b1m5d3 -y -medperf result submit -r b1m3d3 -y - - -# read results -medperf auth login -e benchmarkowner@example.com -medperf result view -b 1 - -############ test other stuff -medperf auth login -e modelowner@example.com -medperf training ls -medperf aggregator ls -medperf training view 1 -medperf aggregator view 1 -medperf training list_associations \ No newline at end of file diff --git a/server/testing_miccai_shortcut.sh b/server/testing_miccai_shortcut.sh deleted file mode 100644 index a611bc7df..000000000 --- a/server/testing_miccai_shortcut.sh +++ /dev/null @@ -1,7 +0,0 @@ -# register dataset -medperf auth login -e traincol2@example.com -medperf dataset create -p 1 -d ../../datasets_folder_final/col2 -l ../../datasets_folder_final/col1 --name col2 --description col2data --location col2location -medperf dataset submit -d 54ea1643f6006ead7e8517cd65fd5275f99abe7349895be25bd8485761cde088 -y - -# associate dataset -medperf training associate_dataset -t 1 -d 2 -y From ca72d8e858a86713f7064b2e58d525b1d7a6d837 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 6 Mar 2024 17:35:13 +0100 Subject: [PATCH 012/242] debug --- .github/workflows/train-ci.yml | 5 +++++ cli/cli_tests_training.sh | 2 ++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml index 473e7460b..791960911 100644 --- a/.github/workflows/train-ci.yml +++ b/.github/workflows/train-ci.yml @@ -15,6 +15,11 @@ jobs: with: python-version: '3.9' + - name: debug + run: | + docker run hasan7/testhostname:1.0.0 $(hostname --fqdn) + docker run 
hasan7/testhostname:1.0.0 $(hostname -A | cut -d " " -f 1) + - name: Install dependencies working-directory: . run: | diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 491174d4c..842137b33 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -136,6 +136,8 @@ echo "=====================================" HOSTNAME=$(hostname -A | cut -d " " -f 1) medperf aggregator submit -n aggreg -a $HOSTNAME -p 50273 checkFailed "aggregator submission step failed" +medperf aggregator ls +medperf aggregator ls | grep aggreg | tr -s ' ' | cut -d ' ' -f 1 AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | cut -d ' ' -f 1) ########################################################## From c5ce774b256105c3c4a131d3451d510452391563 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 6 Mar 2024 17:37:44 +0100 Subject: [PATCH 013/242] debug --- .github/workflows/train-ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml index 791960911..d7da5cb29 100644 --- a/.github/workflows/train-ci.yml +++ b/.github/workflows/train-ci.yml @@ -18,7 +18,6 @@ jobs: - name: debug run: | docker run hasan7/testhostname:1.0.0 $(hostname --fqdn) - docker run hasan7/testhostname:1.0.0 $(hostname -A | cut -d " " -f 1) - name: Install dependencies working-directory: . 
From 77d099a44d77c00a2abf2f35443ed7da1f79e416 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 6 Mar 2024 17:39:53 +0100 Subject: [PATCH 014/242] debug --- .github/workflows/train-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml index d7da5cb29..643c75238 100644 --- a/.github/workflows/train-ci.yml +++ b/.github/workflows/train-ci.yml @@ -17,7 +17,7 @@ jobs: - name: debug run: | - docker run hasan7/testhostname:1.0.0 $(hostname --fqdn) + docker run hasan7/testhostname:1.0.0 $(hostname -A | cut -d " " -f 1) - name: Install dependencies working-directory: . From e1427de6f5269386182afdba7c27b01366c9be3e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 6 Mar 2024 17:44:01 +0100 Subject: [PATCH 015/242] debug --- .github/workflows/train-ci.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/train-ci.yml b/.github/workflows/train-ci.yml index 643c75238..473e7460b 100644 --- a/.github/workflows/train-ci.yml +++ b/.github/workflows/train-ci.yml @@ -15,10 +15,6 @@ jobs: with: python-version: '3.9' - - name: debug - run: | - docker run hasan7/testhostname:1.0.0 $(hostname -A | cut -d " " -f 1) - - name: Install dependencies working-directory: . 
run: | From b5edaff2ee4230a06a5845077de0f6acb8e8c280 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 12 Mar 2024 01:49:42 +0100 Subject: [PATCH 016/242] add prep cube for FL tests --- examples/fl/prep/README.md | 10 +++++ examples/fl/prep/build.sh | 1 + examples/fl/prep/clean.sh | 3 ++ examples/fl/prep/mlcube/.gitignore | 1 + examples/fl/prep/mlcube/mlcube.yaml | 40 ++++++++++++++++++ examples/fl/prep/project/Dockerfile | 11 +++++ examples/fl/prep/project/mlcube.py | 38 +++++++++++++++++ examples/fl/prep/project/prepare.py | 50 +++++++++++++++++++++++ examples/fl/prep/project/requirements.txt | 6 +++ examples/fl/prep/project/sanity_check.py | 11 +++++ examples/fl/prep/project/stats.py | 23 +++++++++++ examples/fl/prep/test.sh | 7 ++++ 12 files changed, 201 insertions(+) create mode 100644 examples/fl/prep/README.md create mode 100644 examples/fl/prep/build.sh create mode 100644 examples/fl/prep/clean.sh create mode 100644 examples/fl/prep/mlcube/.gitignore create mode 100644 examples/fl/prep/mlcube/mlcube.yaml create mode 100644 examples/fl/prep/project/Dockerfile create mode 100644 examples/fl/prep/project/mlcube.py create mode 100644 examples/fl/prep/project/prepare.py create mode 100644 examples/fl/prep/project/requirements.txt create mode 100644 examples/fl/prep/project/sanity_check.py create mode 100644 examples/fl/prep/project/stats.py create mode 100644 examples/fl/prep/test.sh diff --git a/examples/fl/prep/README.md b/examples/fl/prep/README.md new file mode 100644 index 000000000..b8bbdffad --- /dev/null +++ b/examples/fl/prep/README.md @@ -0,0 +1,10 @@ +# How to test + +1. download a dataset: + + - train: + - test: + +2. Extract the dataset. Place the folder `col1` under the workspace folder and rename it to `input_data`. +3. Create an empty folder named `input_labels` under the workspace folder. +4. 
Run `test.sh` diff --git a/examples/fl/prep/build.sh b/examples/fl/prep/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/prep/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/prep/clean.sh b/examples/fl/prep/clean.sh new file mode 100644 index 000000000..68beabb21 --- /dev/null +++ b/examples/fl/prep/clean.sh @@ -0,0 +1,3 @@ +rm -rf mlcube/workspace/data +rm -rf mlcube/workspace/labels +rm -rf mlcube/workspace/statistics.yaml diff --git a/examples/fl/prep/mlcube/.gitignore b/examples/fl/prep/mlcube/.gitignore new file mode 100644 index 000000000..f1981605f --- /dev/null +++ b/examples/fl/prep/mlcube/.gitignore @@ -0,0 +1 @@ +workspace \ No newline at end of file diff --git a/examples/fl/prep/mlcube/mlcube.yaml b/examples/fl/prep/mlcube/mlcube.yaml new file mode 100644 index 000000000..15c0133b9 --- /dev/null +++ b/examples/fl/prep/mlcube/mlcube.yaml @@ -0,0 +1,40 @@ +name: pathmnist data preparation MLCube +description: pathmnist data preparation MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/fl-test-prep:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile" + +tasks: + prepare: + parameters: + inputs: + { + data_path: input_data, + labels_path: input_labels, + } + outputs: { output_path: data/, output_labels_path: labels/ } + sanity_check: + parameters: + inputs: + { + data_path: data/, + labels_path: labels/, + } + statistics: + parameters: + inputs: + { + data_path: data/, + labels_path: labels/, + } + outputs: { output_path: { type: file, default: statistics.yaml } } diff --git a/examples/fl/prep/project/Dockerfile b/examples/fl/prep/project/Dockerfile new file mode 100644 index 000000000..91c477415 --- /dev/null +++ b/examples/fl/prep/project/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.9.16-slim + +COPY ./requirements.txt /mlcube_project/requirements.txt + +RUN pip3 install --no-cache-dir -r /mlcube_project/requirements.txt + +ENV LANG C.UTF-8 + +COPY . /mlcube_project + +ENTRYPOINT ["python3", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/prep/project/mlcube.py b/examples/fl/prep/project/mlcube.py new file mode 100644 index 000000000..2e3f03556 --- /dev/null +++ b/examples/fl/prep/project/mlcube.py @@ -0,0 +1,38 @@ +"""MLCube handler file""" +import typer +from prepare import prepare_dataset +from sanity_check import perform_sanity_checks +from stats import generate_statistics + +app = typer.Typer() + + +@app.command("prepare") +def prepare( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + output_path: str = typer.Option(..., "--output_path"), + output_labels_path: str = typer.Option(..., "--output_labels_path"), +): + prepare_dataset(data_path, labels_path, output_path, output_labels_path) + + +@app.command("sanity_check") +def sanity_check( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), +): + perform_sanity_checks(data_path, labels_path) + + +@app.command("statistics") +def statistics( + data_path: str = typer.Option(..., 
"--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + out_path: str = typer.Option(..., "--output_path"), +): + generate_statistics(data_path, labels_path, out_path) + + +if __name__ == "__main__": + app() diff --git a/examples/fl/prep/project/prepare.py b/examples/fl/prep/project/prepare.py new file mode 100644 index 000000000..9bb44f418 --- /dev/null +++ b/examples/fl/prep/project/prepare.py @@ -0,0 +1,50 @@ +import os +import numpy as np +import pandas as pd +import shutil +from tqdm import tqdm +from PIL import Image + + +def prepare_split(split, arrays, output_path, output_labels_path): + subfolder = os.path.join(output_path, "pathmnist") + os.makedirs(subfolder, exist_ok=True) + + arrs = arrays[f"{split}_images"] + labels = arrays[f"{split}_labels"] + csv_data = [] + for i in tqdm(range(arrs.shape[0])): + name = f"{split}_{i}.png" + out_path = os.path.join(subfolder, name) + Image.fromarray(arrs[i]).save(out_path) + record = { + "SubjectID": str(i), + "Channel_0": os.path.join("pathmnist", name), + "valuetopredict": labels[i][0], + } + csv_data.append(record) + + if split == "train": + csv_file = os.path.join(output_path, "train.csv") + if split == "val": + csv_file = os.path.join(output_path, "valid.csv") + if split == "test": + csv_file = os.path.join(output_path, "data.csv") + + pd.DataFrame(csv_data).to_csv(csv_file, index=False) + + if split == "test": + csv_file_in_labels = os.path.join(output_labels_path, "data.csv") + shutil.copyfile(csv_file, csv_file_in_labels) + + +def prepare_dataset(data_path, labels_path, output_path, output_labels_path): + os.makedirs(output_path, exist_ok=True) + os.makedirs(output_labels_path, exist_ok=True) + + file_path = os.path.join(data_path, "pathmnist.npz") + arrays = np.load(file_path) + for key in arrays.keys(): + if key.endswith("images"): + split = key.split("_")[0] + prepare_split(split, arrays, output_path, output_labels_path) diff --git a/examples/fl/prep/project/requirements.txt 
b/examples/fl/prep/project/requirements.txt new file mode 100644 index 000000000..751a77d17 --- /dev/null +++ b/examples/fl/prep/project/requirements.txt @@ -0,0 +1,6 @@ +typer==0.9.0 +numpy==1.26.0 +PyYAML==6.0 +Pillow==10.2.0 +pandas==2.2.1 +tqdm \ No newline at end of file diff --git a/examples/fl/prep/project/sanity_check.py b/examples/fl/prep/project/sanity_check.py new file mode 100644 index 000000000..5fb23d72c --- /dev/null +++ b/examples/fl/prep/project/sanity_check.py @@ -0,0 +1,11 @@ +import os + + +def perform_sanity_checks(data_path, labels_path): + images_files = os.listdir(os.path.join(data_path, "pathmnist")) + + assert all( + [image.endswith(".png") for image in images_files] + ), "images should be .png" + + print("Sanity checks ran successfully.") diff --git a/examples/fl/prep/project/stats.py b/examples/fl/prep/project/stats.py new file mode 100644 index 000000000..872ed6f02 --- /dev/null +++ b/examples/fl/prep/project/stats.py @@ -0,0 +1,23 @@ +import os +import yaml + + +def generate_statistics(data_path, labels_path, out_path): + # number of cases + cases = os.listdir(os.path.join(data_path, "pathmnist")) + if cases[0].startswith("test"): + statistics = { + "num_cases": len(cases), + } + else: + num_train_cases = len([file for file in cases if file.startswith("train")]) + num_val_cases = len([file for file in cases if file.startswith("val")]) + statistics = { + "num_train_cases": num_train_cases, + "num_val_cases": num_val_cases, + } + + + # write statistics + with open(out_path, "w") as f: + yaml.safe_dump(statistics, f) diff --git a/examples/fl/prep/test.sh b/examples/fl/prep/test.sh new file mode 100644 index 000000000..c16ee1ca0 --- /dev/null +++ b/examples/fl/prep/test.sh @@ -0,0 +1,7 @@ +# mlcube run --mlcube ./mlcube --task prepare +# mlcube run --mlcube ./mlcube --task sanity_check +# mlcube run --mlcube ./mlcube --task statistics + +medperf mlcube run --mlcube ./mlcube --task prepare -o ./logs_prep.log +medperf mlcube run --mlcube 
./mlcube --task sanity_check -o ./logs_sanity.log +medperf mlcube run --mlcube ./mlcube --task statistics -o ./logs_stats.log From a8f2f94c7511772190344a215ef658421ed1350c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 12 Mar 2024 03:12:26 +0100 Subject: [PATCH 017/242] update and restructure fl example --- examples/fl/fl/.gitignore | 3 + examples/fl/fl/README.md | 6 + examples/fl/fl/build.sh | 1 + examples/fl/fl/clean.sh | 4 + examples/fl/fl/csr.conf | 23 + .../mlcube-cpu.yaml => fl/mlcube/mlcube.yaml} | 4 +- .../mlcube/workspace/parameters.yaml} | 10 +- examples/fl/fl/project/Dockerfile | 22 + examples/fl/fl/project/README.md | 38 + examples/fl/{ => fl}/project/aggregator.py | 31 +- examples/fl/fl/project/collaborator.py | 33 + examples/fl/fl/project/hooks.py | 101 +++ examples/fl/fl/project/mlcube.py | 117 +++ .../project}/requirements.txt | 1 + examples/fl/{ => fl}/project/utils.py | 101 +-- examples/fl/fl/setup_clean.sh | 4 + examples/fl/fl/setup_test.sh | 86 +++ examples/fl/fl/sync.sh | 7 + examples/fl/fl/test.sh | 3 + examples/fl/mlcube/mlcube-gpu.yaml | 41 -- examples/fl/mlcube/workspace/network.yaml | 5 - .../fl/mlcube/workspace/parameters-cpu.yaml | 137 ---- .../fl/mlcube/workspace/parameters-gpu.yaml | 137 ---- examples/fl/project/Dockerfile-CPU | 31 - examples/fl/project/Dockerfile-GPU | 28 - examples/fl/project/README.md | 43 -- examples/fl/project/collaborator.py | 29 - examples/fl/project/fl_workspace/.workspace | 2 - .../fl/project/fl_workspace/plan/defaults | 2 - examples/fl/project/hotfix.py | 667 ------------------ examples/fl/project/mlcube.py | 34 - 31 files changed, 509 insertions(+), 1242 deletions(-) create mode 100644 examples/fl/fl/.gitignore create mode 100644 examples/fl/fl/README.md create mode 100644 examples/fl/fl/build.sh create mode 100644 examples/fl/fl/clean.sh create mode 100644 examples/fl/fl/csr.conf rename examples/fl/{mlcube/mlcube-cpu.yaml => fl/mlcube/mlcube.yaml} (93%) rename 
examples/fl/{mlcube/workspace/parameters-miccai.yaml => fl/mlcube/workspace/parameters.yaml} (97%) create mode 100644 examples/fl/fl/project/Dockerfile create mode 100644 examples/fl/fl/project/README.md rename examples/fl/{ => fl}/project/aggregator.py (51%) create mode 100644 examples/fl/fl/project/collaborator.py create mode 100644 examples/fl/fl/project/hooks.py create mode 100644 examples/fl/fl/project/mlcube.py rename examples/fl/{project/fl_workspace => fl/project}/requirements.txt (52%) rename examples/fl/{ => fl}/project/utils.py (54%) create mode 100644 examples/fl/fl/setup_clean.sh create mode 100644 examples/fl/fl/setup_test.sh create mode 100644 examples/fl/fl/sync.sh create mode 100644 examples/fl/fl/test.sh delete mode 100644 examples/fl/mlcube/mlcube-gpu.yaml delete mode 100644 examples/fl/mlcube/workspace/network.yaml delete mode 100644 examples/fl/mlcube/workspace/parameters-cpu.yaml delete mode 100644 examples/fl/mlcube/workspace/parameters-gpu.yaml delete mode 100644 examples/fl/project/Dockerfile-CPU delete mode 100644 examples/fl/project/Dockerfile-GPU delete mode 100644 examples/fl/project/README.md delete mode 100644 examples/fl/project/collaborator.py delete mode 100644 examples/fl/project/fl_workspace/.workspace delete mode 100644 examples/fl/project/fl_workspace/plan/defaults delete mode 100644 examples/fl/project/hotfix.py delete mode 100644 examples/fl/project/mlcube.py diff --git a/examples/fl/fl/.gitignore b/examples/fl/fl/.gitignore new file mode 100644 index 000000000..6bd8bf2e2 --- /dev/null +++ b/examples/fl/fl/.gitignore @@ -0,0 +1,3 @@ +mlcube_* +ca +quick* diff --git a/examples/fl/fl/README.md b/examples/fl/fl/README.md new file mode 100644 index 000000000..d4228439d --- /dev/null +++ b/examples/fl/fl/README.md @@ -0,0 +1,6 @@ +# How to run tests + +- Run `setup_test.sh` just once to create certs and download required data. +- Run `test.sh` to start the aggregator and two collaborators. 
+- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/fl/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/fl/clean.sh b/examples/fl/fl/clean.sh new file mode 100644 index 000000000..f7806d151 --- /dev/null +++ b/examples/fl/fl/clean.sh @@ -0,0 +1,4 @@ +rm -rf mlcube_agg/workspace/final_weights +rm -rf mlcube_agg/workspace/logs +rm -rf mlcube_col1/workspace/logs +rm -rf mlcube_col2/workspace/logs diff --git a/examples/fl/fl/csr.conf b/examples/fl/fl/csr.conf new file mode 100644 index 000000000..3285aed9f --- /dev/null +++ b/examples/fl/fl/csr.conf @@ -0,0 +1,23 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[ req_distinguished_name ] +commonName = hasan-hp-zbook-15-g3.home + +[ req_ext ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = hasan-hp-zbook-15-g3.home + +[ v3_client ] +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +extendedKeyUsage = critical,serverAuth diff --git a/examples/fl/mlcube/mlcube-cpu.yaml b/examples/fl/fl/mlcube/mlcube.yaml similarity index 93% rename from examples/fl/mlcube/mlcube-cpu.yaml rename to examples/fl/fl/mlcube/mlcube.yaml index a32d885cb..f7a67b805 100644 --- a/examples/fl/mlcube/mlcube-cpu.yaml +++ b/examples/fl/fl/mlcube/mlcube.yaml @@ -8,11 +8,11 @@ platform: docker: # Image name - image: hasan7/fltest:0.0.0-cpu + image: mlcommons/medperf-fl:1.0.0 # Docker build context relative to $MLCUBE_ROOT. Default is `build`. build_context: "../project" # Docker file name within docker build context, default is `Dockerfile`. 
- build_file: "Dockerfile-CPU" + build_file: "Dockerfile" tasks: train: diff --git a/examples/fl/mlcube/workspace/parameters-miccai.yaml b/examples/fl/fl/mlcube/workspace/parameters.yaml similarity index 97% rename from examples/fl/mlcube/workspace/parameters-miccai.yaml rename to examples/fl/fl/mlcube/workspace/parameters.yaml index a9ec969e5..04acd542c 100644 --- a/examples/fl/mlcube/workspace/parameters-miccai.yaml +++ b/examples/fl/fl/mlcube/workspace/parameters.yaml @@ -5,7 +5,7 @@ plan: db_store_rounds: 2 init_state_path: save/classification_init.pbuf last_state_path: save/classification_last.pbuf - rounds_to_train: 3 + rounds_to_train: 2 write_logs: true template: openfl.component.Aggregator assigner: @@ -107,7 +107,7 @@ plan: nested_training: testing: 1 validation: -5 - num_epochs: 5 + num_epochs: 2 opt: adam optimizer: type: adam @@ -118,7 +118,7 @@ plan: - 128 - 128 - 1 - patience: 5 + patience: 1 pin_memory_dataloader: false print_rgb_label_warning: true q_max_length: 5 @@ -135,8 +135,8 @@ plan: track_memory_usage: false verbose: false version: - maximum: 0.0.14 - minimum: 0.0.14 + maximum: 0.0.19 + minimum: 0.0.19 weighted_loss: true train_csv: train_path_full.csv val_csv: val_path_full.csv diff --git a/examples/fl/fl/project/Dockerfile b/examples/fl/fl/project/Dockerfile new file mode 100644 index 000000000..90d680efc --- /dev/null +++ b/examples/fl/fl/project/Dockerfile @@ -0,0 +1,22 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" + + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +RUN pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 && \ + pip install --no-cache-dir openvino-dev==2023.0.1 && \ + git clone https://github.com/mlcommons/GaNDLF.git && \ + cd GaNDLF && git checkout 64962a1d3416071299a452126baccd3163f0b2d8 && \ + pip 
install --no-cache-dir -e . + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube workspace +COPY . /mlcube_project + +ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/fl/project/README.md b/examples/fl/fl/project/README.md new file mode 100644 index 000000000..21ad970e9 --- /dev/null +++ b/examples/fl/fl/project/README.md @@ -0,0 +1,38 @@ +# How to configure container build for your application + +- List your pip requirements in `requirements.txt` +- List your software requirements in `Dockerfile` +- Modify the functions in `hooks.py` as needed. (Explanation TBD) + +# How to configure container for custom FL software + +- Change the base Docker image as needed. +- modify `aggregator.py` and `collaborator.py` as needed. Follow the implemented schema steps. + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +cd openfl +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. 
+ +# How to customize + +TBD diff --git a/examples/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py similarity index 51% rename from examples/fl/project/aggregator.py rename to examples/fl/fl/project/aggregator.py index 61d583a12..fd7e60a5d 100644 --- a/examples/fl/project/aggregator.py +++ b/examples/fl/fl/project/aggregator.py @@ -5,11 +5,12 @@ prepare_plan, prepare_cols_list, prepare_init_weights, + create_workspace, get_weights_path, - WORKSPACE, ) import os +import shutil from subprocess import check_call from distutils.dir_util import copy_tree @@ -24,27 +25,33 @@ def start_aggregator( network_config, collaborators, ): - prepare_plan(parameters_file, network_config) - prepare_cols_list(collaborators) - prepare_init_weights(input_weights) - fqdn = get_aggregator_fqdn() - prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}") - prepare_ca_cert(ca_cert_folder) - check_call(["fx", "aggregator", "start"], cwd=WORKSPACE) + workspace_folder = os.path.join(output_logs, "workspace") + create_workspace(workspace_folder) + prepare_plan(parameters_file, network_config, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + check_call(["fx", "aggregator", "start"], cwd=workspace_folder) # TODO: check how to copy logs during runtime. # perhaps investigate overriding plan entries? # NOTE: logs and weights are copied, even if target folders are not empty - copy_tree(os.path.join(WORKSPACE, "logs"), output_logs) + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) # NOTE: conversion fails since openfl needs sample data... 
- # weights_paths = get_weights_path() + # weights_paths = get_weights_path(fl_workspace) # out_best = os.path.join(output_weights, "best") # out_last = os.path.join(output_weights, "last") # check_call( # ["fx", "model", "save", "-i", weights_paths["best"], "-o", out_best], - # cwd=WORKSPACE, + # cwd=workspace_folder, # ) - copy_tree(os.path.join(WORKSPACE, "save"), output_weights) + copy_tree(os.path.join(workspace_folder, "save"), output_weights) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) diff --git a/examples/fl/fl/project/collaborator.py b/examples/fl/fl/project/collaborator.py new file mode 100644 index 000000000..0762f5b82 --- /dev/null +++ b/examples/fl/fl/project/collaborator.py @@ -0,0 +1,33 @@ +from utils import ( + get_collaborator_cn, + prepare_node_cert, + prepare_ca_cert, + prepare_plan, + create_workspace, +) +import os +import shutil +from subprocess import check_call + + +def start_collaborator( + data_path, + labels_path, + parameters_file, + node_cert_folder, + ca_cert_folder, + network_config, + output_logs, +): + workspace_folder = os.path.join(output_logs, "workspace") + create_workspace(workspace_folder) + prepare_plan(parameters_file, network_config, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + # set log files + check_call(["fx", "collaborator", "start", "-n", cn], cwd=workspace_folder) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) diff --git a/examples/fl/fl/project/hooks.py b/examples/fl/fl/project/hooks.py new file mode 100644 index 000000000..dfa4792b5 --- /dev/null +++ b/examples/fl/fl/project/hooks.py @@ -0,0 +1,101 @@ +import os +import pandas as pd +from utils import get_collaborator_cn + + +def __modify_df(df): + # gandlf convention: labels columns could be "target", "label", "mask" + # subject id column is subjectid. data columns are Channel_0. 
+ # Others could be scalars. # TODO + labels_columns = ["target", "label", "mask"] + data_columns = ["channel_0"] + subject_id_column = "subjectid" + for column in df.columns: + if column.lower() == subject_id_column: + continue + if column.lower() in labels_columns: + prepend_str = "labels/" + elif column.lower() in data_columns: + prepend_str = "data/" + else: + continue + + df[column] = prepend_str + df[column].astype(str) + + +def collaborator_pre_training_hook( + data_path, + labels_path, + parameters_file, + node_cert_folder, + ca_cert_folder, + network_config, + output_logs, +): + cn = get_collaborator_cn() + workspace_folder = os.path.join(output_logs, "workspace") + + target_data_folder = os.path.join(workspace_folder, "data", cn) + os.makedirs(target_data_folder, exist_ok=True) + target_data_data_folder = os.path.join(target_data_folder, "data") + target_data_labels_folder = os.path.join(target_data_folder, "labels") + target_train_csv = os.path.join(target_data_folder, "train.csv") + target_valid_csv = os.path.join(target_data_folder, "valid.csv") + + os.symlink(data_path, target_data_data_folder) + os.symlink(labels_path, target_data_labels_folder) + train_csv = os.path.join(data_path, "train.csv") + valid_csv = os.path.join(data_path, "valid.csv") + + train_df = pd.read_csv(train_csv) + __modify_df(train_df) + train_df.to_csv(target_train_csv, index=False) + + valid_df = pd.read_csv(valid_csv) + __modify_df(valid_df) + valid_df.to_csv(target_valid_csv, index=False) + + data_config = f"{cn},data/{cn}" + plan_folder = os.path.join(workspace_folder, "plan") + os.makedirs(plan_folder, exist_ok=True) + data_config_path = os.path.join(plan_folder, "data.yaml") + with open(data_config_path, "w") as f: + f.write(data_config) + + +def collaborator_post_training_hook( + data_path, + labels_path, + parameters_file, + node_cert_folder, + ca_cert_folder, + network_config, + output_logs, +): + pass + + +def aggregator_pre_training_hook( + input_weights, + 
parameters_file, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + network_config, + collaborators, +): + pass + + +def aggregator_post_training_hook( + input_weights, + parameters_file, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + network_config, + collaborators, +): + pass diff --git a/examples/fl/fl/project/mlcube.py b/examples/fl/fl/project/mlcube.py new file mode 100644 index 000000000..d88a8eb29 --- /dev/null +++ b/examples/fl/fl/project/mlcube.py @@ -0,0 +1,117 @@ +"""MLCube handler file""" + +import os +import shutil +import typer +from collaborator import start_collaborator +from aggregator import start_aggregator +from hooks import ( + aggregator_pre_training_hook, + aggregator_post_training_hook, + collaborator_pre_training_hook, + collaborator_post_training_hook, +) + +app = typer.Typer() + + +def _setup(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + os.makedirs(tmp_folder, exist_ok=True) + # TODO: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmp_folder + + +def _teardown(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +@app.command("train") +def train( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + parameters_file: str = typer.Option(..., "--parameters_file"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + network_config: str = typer.Option(..., "--network_config"), + output_logs: str = typer.Option(..., "--output_logs"), +): + _setup(output_logs) + collaborator_pre_training_hook( + data_path=data_path, + labels_path=labels_path, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + network_config=network_config, + output_logs=output_logs, + ) + start_collaborator( + data_path=data_path, 
+ labels_path=labels_path, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + network_config=network_config, + output_logs=output_logs, + ) + collaborator_post_training_hook( + data_path=data_path, + labels_path=labels_path, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + network_config=network_config, + output_logs=output_logs, + ) + _teardown(output_logs) + + +@app.command("start_aggregator") +def start_aggregator_( + input_weights: str = typer.Option(..., "--input_weights"), + parameters_file: str = typer.Option(..., "--parameters_file"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + output_logs: str = typer.Option(..., "--output_logs"), + output_weights: str = typer.Option(..., "--output_weights"), + network_config: str = typer.Option(..., "--network_config"), + collaborators: str = typer.Option(..., "--collaborators"), +): + _setup(output_logs) + aggregator_pre_training_hook( + input_weights=input_weights, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + network_config=network_config, + collaborators=collaborators, + ) + start_aggregator( + input_weights=input_weights, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + network_config=network_config, + collaborators=collaborators, + ) + aggregator_post_training_hook( + input_weights=input_weights, + parameters_file=parameters_file, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + network_config=network_config, + collaborators=collaborators, + ) + _teardown(output_logs) + + +if __name__ == "__main__": + 
app() diff --git a/examples/fl/project/fl_workspace/requirements.txt b/examples/fl/fl/project/requirements.txt similarity index 52% rename from examples/fl/project/fl_workspace/requirements.txt rename to examples/fl/fl/project/requirements.txt index 709016a50..c7ec9886d 100644 --- a/examples/fl/project/fl_workspace/requirements.txt +++ b/examples/fl/fl/project/requirements.txt @@ -1 +1,2 @@ onnx==1.13.0 +typer==0.9.0 \ No newline at end of file diff --git a/examples/fl/project/utils.py b/examples/fl/fl/project/utils.py similarity index 54% rename from examples/fl/project/utils.py rename to examples/fl/fl/project/utils.py index 52c1eeeca..9cd42fa25 100644 --- a/examples/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -1,12 +1,21 @@ import yaml import os -import pandas as pd -WORKSPACE = os.environ["FL_WORKSPACE"] +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") -def get_aggregator_fqdn(): - plan_path = os.path.join(WORKSPACE, "plan", "plan.yaml") + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_aggregator_fqdn(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") plan = yaml.safe_load(open(plan_path)) return plan["network"]["settings"]["agg_addr"].lower() @@ -17,8 +26,8 @@ def get_collaborator_cn(): return os.environ["COLLABORATOR_CN"] -def get_weights_path(): - plan_path = os.path.join(WORKSPACE, "plan", "plan.yaml") +def get_weights_path(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") plan = yaml.safe_load(open(plan_path)) return { "init": plan["aggregator"]["settings"]["init_state_path"], @@ -27,7 +36,7 @@ def get_weights_path(): } -def prepare_plan(parameters_file, 
network_config): +def prepare_plan(parameters_file, network_config, fl_workspace): with open(parameters_file) as f: params = yaml.safe_load(f) if "plan" not in params: @@ -36,42 +45,47 @@ def prepare_plan(parameters_file, network_config): network_config_dict = yaml.safe_load(f) plan = params["plan"] plan["network"]["settings"].update(network_config_dict) - target_plan_folder = os.path.join(WORKSPACE, "plan") + target_plan_folder = os.path.join(fl_workspace, "plan") + # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "plan.yaml") with open(target_plan_file, "w") as f: yaml.dump(plan, f) -def prepare_cols_list(collaborators_file): +def prepare_cols_list(collaborators_file, fl_workspace): with open(collaborators_file) as f: cols = f.read().strip().split("\n") - target_plan_folder = os.path.join(WORKSPACE, "plan") + target_plan_folder = os.path.join(fl_workspace, "plan") + # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "cols.yaml") with open(target_plan_file, "w") as f: yaml.dump({"collaborators": cols}, f) -def prepare_init_weights(input_weights): +def prepare_init_weights(input_weights, fl_workspace): error_msg = f"{input_weights} should contain only one file: *.pbuf" files = os.listdir(input_weights) - file = files[0] + file = files[0] # TODO: this may cause failure in MAC OS if len(files) != 1 or not file.endswith(".pbuf"): raise RuntimeError(error_msg) file = os.path.join(input_weights, file) - target_weights_subpath = get_weights_path()["init"] - target_weights_path = os.path.join(WORKSPACE, target_weights_subpath) + target_weights_subpath = get_weights_path(fl_workspace)["init"] + target_weights_path = os.path.join(fl_workspace, target_weights_subpath) target_weights_folder = os.path.dirname(target_weights_path) + # TODO: permissions os.makedirs(target_weights_folder, exist_ok=True) os.symlink(file, target_weights_path) 
-def prepare_node_cert(node_cert_folder, target_cert_folder_name, target_cert_name): +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" files = os.listdir(node_cert_folder) @@ -89,7 +103,8 @@ def prepare_node_cert(node_cert_folder, target_cert_folder_name, target_cert_nam key_file = os.path.join(node_cert_folder, key_file) cert_file = os.path.join(node_cert_folder, cert_file) - target_cert_folder = os.path.join(WORKSPACE, "cert", target_cert_folder_name) + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + # TODO: permissions os.makedirs(target_cert_folder, exist_ok=True) target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") @@ -98,7 +113,7 @@ def prepare_node_cert(node_cert_folder, target_cert_folder_name, target_cert_nam os.symlink(cert_file, target_cert_file) -def prepare_ca_cert(ca_cert_folder): +def prepare_ca_cert(ca_cert_folder, fl_workspace): error_msg = f"{ca_cert_folder} should contain only one file: *.crt" files = os.listdir(ca_cert_folder) @@ -108,57 +123,9 @@ def prepare_ca_cert(ca_cert_folder): file = os.path.join(ca_cert_folder, file) - target_ca_cert_folder = os.path.join(WORKSPACE, "cert") + target_ca_cert_folder = os.path.join(fl_workspace, "cert") + # TODO: permissions os.makedirs(target_ca_cert_folder, exist_ok=True) target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") os.symlink(file, target_ca_cert_file) - - -def __modify_df(df): - # gandlf convention: labels columns could be "target", "label", "mask" - # subject id column is subjectid. data columns are Channel_0. - # Others could be scalars. 
# TODO - labels_columns = ["target", "label", "mask"] - data_columns = ["channel_0"] - subject_id_column = "subjectid" - for column in df.columns: - if column.lower() == subject_id_column: - continue - if column.lower() in labels_columns: - prepend_str = "labels/" - elif column.lower() in data_columns: - prepend_str = "data/" - else: - continue - - df[column] = prepend_str + df[column].astype(str) - - -def prepare_data(data_path, labels_path, cn): - target_data_folder = os.path.join(WORKSPACE, "data", cn) - os.makedirs(target_data_folder, exist_ok=True) - target_data_data_folder = os.path.join(target_data_folder, "data") - target_data_labels_folder = os.path.join(target_data_folder, "labels") - target_train_csv = os.path.join(target_data_folder, "train.csv") - target_valid_csv = os.path.join(target_data_folder, "valid.csv") - - os.symlink(data_path, target_data_data_folder) - os.symlink(labels_path, target_data_labels_folder) - train_csv = os.path.join(data_path, "train.csv") - valid_csv = os.path.join(data_path, "valid.csv") - - train_df = pd.read_csv(train_csv) - __modify_df(train_df) - train_df.to_csv(target_train_csv, index=False) - - valid_df = pd.read_csv(valid_csv) - __modify_df(valid_df) - valid_df.to_csv(target_valid_csv, index=False) - - data_config = f"{cn},data/{cn}" - plan_folder = os.path.join(WORKSPACE, "plan") - os.makedirs(plan_folder, exist_ok=True) - data_config_path = os.path.join(plan_folder, "data.yaml") - with open(data_config_path, "w") as f: - f.write(data_config) diff --git a/examples/fl/fl/setup_clean.sh b/examples/fl/fl/setup_clean.sh new file mode 100644 index 000000000..18ca5536c --- /dev/null +++ b/examples/fl/fl/setup_clean.sh @@ -0,0 +1,4 @@ +rm -rf ./mlcube_agg +rm -rf ./mlcube_col1 +rm -rf ./mlcube_col2 +rm -rf ./ca diff --git a/examples/fl/fl/setup_test.sh b/examples/fl/fl/setup_test.sh new file mode 100644 index 000000000..0285787e3 --- /dev/null +++ b/examples/fl/fl/setup_test.sh @@ -0,0 +1,86 @@ +cp -r ./mlcube ./mlcube_agg 
+cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i '/^commonName = /c\commonName = col1' csr.conf +sed -i '/^DNS\.1 = /c\DNS.1 = col1' csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col2 +sed -i '/^commonName = /c\commonName = col2' csr.conf +sed -i '/^DNS\.1 = /c\DNS.1 = col2' csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA 
../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# network file +echo "agg_addr: $HOSTNAME_" >>mlcube_col1/workspace/network.yaml +echo "agg_port: 50273" >>mlcube_col1/workspace/network.yaml +echo "address: $HOSTNAME_" >>mlcube_col1/workspace/network.yaml +echo "port: 50273" >>mlcube_col1/workspace/network.yaml + +cp mlcube_col1/workspace/network.yaml mlcube_col2/workspace/network.yaml +cp mlcube_col1/workspace/network.yaml mlcube_agg/workspace/network.yaml + +# cols file +echo "col1" >>mlcube_agg/workspace/cols.yaml +echo "col2" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. 
diff --git a/examples/fl/fl/sync.sh b/examples/fl/fl/sync.sh new file mode 100644 index 000000000..1a36e5ab5 --- /dev/null +++ b/examples/fl/fl/sync.sh @@ -0,0 +1,7 @@ +cp mlcube/workspace/parameters.yaml mlcube_agg/workspace/parameters.yaml +cp mlcube/workspace/parameters.yaml mlcube_col1/workspace/parameters.yaml +cp mlcube/workspace/parameters.yaml mlcube_col2/workspace/parameters.yaml + +cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh new file mode 100644 index 000000000..f05b2551a --- /dev/null +++ b/examples/fl/fl/test.sh @@ -0,0 +1,3 @@ +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273; bash" +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col1 --task train -e COLLABORATOR_CN=col1; bash" +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col2 --task train -e COLLABORATOR_CN=col2; bash" diff --git a/examples/fl/mlcube/mlcube-gpu.yaml b/examples/fl/mlcube/mlcube-gpu.yaml deleted file mode 100644 index 766c5b68c..000000000 --- a/examples/fl/mlcube/mlcube-gpu.yaml +++ /dev/null @@ -1,41 +0,0 @@ -name: FL MLCube -description: FL MLCube -authors: - - { name: MLCommons Medical Working Group } - -platform: - accelerator_count: 1 - -docker: - # Image name - image: hasan7/fltest:0.0.0-gpu - # Docker build context relative to $MLCUBE_ROOT. Default is `build`. - build_context: "../project" - # Docker file name within docker build context, default is `Dockerfile`. 
- build_file: "Dockerfile-GPU" - gpu_args: --gpus all - -tasks: - train: - parameters: - inputs: - data_path: data/ - labels_path: labels/ - node_cert_folder: node_cert/ - ca_cert_folder: ca_cert/ - parameters_file: parameters.yaml - network_config: network.yaml - outputs: - output_logs: logs/ - start_aggregator: - parameters: - inputs: - input_weights: additional_files/init_weights - node_cert_folder: node_cert/ - ca_cert_folder: ca_cert/ - parameters_file: parameters.yaml - network_config: network.yaml - collaborators: cols.yaml - outputs: - output_logs: logs/ - output_weights: final_weights/ diff --git a/examples/fl/mlcube/workspace/network.yaml b/examples/fl/mlcube/workspace/network.yaml deleted file mode 100644 index 31a4c1466..000000000 --- a/examples/fl/mlcube/workspace/network.yaml +++ /dev/null @@ -1,5 +0,0 @@ -agg_addr: 104.197.235.200 -agg_port: 50273 - -address: 104.197.235.200 -port: 50273 \ No newline at end of file diff --git a/examples/fl/mlcube/workspace/parameters-cpu.yaml b/examples/fl/mlcube/workspace/parameters-cpu.yaml deleted file mode 100644 index 72d04d48f..000000000 --- a/examples/fl/mlcube/workspace/parameters-cpu.yaml +++ /dev/null @@ -1,137 +0,0 @@ -plan: - aggregator: - settings: - best_state_path: save/fets_seg_test_best.pbuf - db_store_rounds: 2 - init_state_path: save/fets_seg_test_init.pbuf - last_state_path: save/fets_seg_test_last.pbuf - rounds_to_train: 2 - write_logs: true - template: openfl.component.Aggregator - assigner: - settings: - task_groups: - - name: train_and_validate - percentage: 1.0 - tasks: - - aggregated_model_validation - - train - - locally_tuned_model_validation - template: openfl.component.RandomGroupedAssigner - collaborator: - settings: - db_store_rounds: 1 - delta_updates: false - opt_treatment: RESET - template: openfl.component.Collaborator - compression_pipeline: - settings: {} - template: openfl.pipelines.NoCompressionPipeline - data_loader: - settings: - feature_shape: - - 32 - - 32 - - 32 - 
template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper - network: - settings: - agg_addr: any - agg_port: any - cert_folder: cert - client_reconnect_interval: 5 - disable_client_auth: false - hash_salt: auto - tls: true - template: openfl.federation.Network - task_runner: - settings: - device: cpu - gandlf_config: - batch_size: 1 - clip_grad: null - clip_mode: null - data_augmentation: {} - data_postprocessing: {} - data_preprocessing: - normalize: null - enable_padding: false - in_memory: true - inference_mechanism: - grid_aggregator_overlap: crop - patch_overlap: 0 - learning_rate: 0.001 - loss_function: dc - medcam_enabled: false - metrics: - - dice - model: - amp: true - architecture: unet - base_filters: 32 - batch_norm: false - class_list: - - 0 - - 1 - dimension: 3 - final_layer: sigmoid - ignore_label_validation: null - norm_type: instance - num_channels: 1 - nested_training: - testing: -5 - validation: -5 - num_epochs: 1 - optimizer: - type: adam - output_dir: . - parallel_compute_command: '' - patch_sampler: uniform - patch_size: - - 32 - - 32 - - 32 - patience: 1 - pin_memory_dataloader: false - print_rgb_label_warning: true - q_max_length: 1 - q_num_workers: 0 - q_samples_per_volume: 1 - q_verbose: false - save_output: false - save_training: false - scaling_factor: 1 - scheduler: - type: triangle - track_memory_usage: false - verbose: false - version: - maximum: 0.0.14 - minimum: 0.0.13 - weighted_loss: true - train_csv: seg_test_train.csv - val_csv: seg_test_val.csv - template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner - tasks: - aggregated_model_validation: - function: validate - kwargs: - apply: global - metrics: - - valid_loss - - valid_dice - locally_tuned_model_validation: - function: validate - kwargs: - apply: local - metrics: - - valid_loss - - valid_dice - settings: {} - train: - function: train - kwargs: - epochs: 1 - metrics: - - loss - - train_dice diff --git a/examples/fl/mlcube/workspace/parameters-gpu.yaml 
b/examples/fl/mlcube/workspace/parameters-gpu.yaml deleted file mode 100644 index b29186193..000000000 --- a/examples/fl/mlcube/workspace/parameters-gpu.yaml +++ /dev/null @@ -1,137 +0,0 @@ -plan: - aggregator: - settings: - best_state_path: save/fets_seg_test_best.pbuf - db_store_rounds: 2 - init_state_path: save/fets_seg_test_init.pbuf - last_state_path: save/fets_seg_test_last.pbuf - rounds_to_train: 2 - write_logs: true - template: openfl.component.Aggregator - assigner: - settings: - task_groups: - - name: train_and_validate - percentage: 1.0 - tasks: - - aggregated_model_validation - - train - - locally_tuned_model_validation - template: openfl.component.RandomGroupedAssigner - collaborator: - settings: - db_store_rounds: 1 - delta_updates: false - opt_treatment: RESET - template: openfl.component.Collaborator - compression_pipeline: - settings: {} - template: openfl.pipelines.NoCompressionPipeline - data_loader: - settings: - feature_shape: - - 32 - - 32 - - 32 - template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper - network: - settings: - agg_addr: any - agg_port: any - cert_folder: cert - client_reconnect_interval: 5 - disable_client_auth: false - hash_salt: auto - tls: true - template: openfl.federation.Network - task_runner: - settings: - device: cuda - gandlf_config: - batch_size: 1 - clip_grad: null - clip_mode: null - data_augmentation: {} - data_postprocessing: {} - data_preprocessing: - normalize: null - enable_padding: false - in_memory: true - inference_mechanism: - grid_aggregator_overlap: crop - patch_overlap: 0 - learning_rate: 0.001 - loss_function: dc - medcam_enabled: false - metrics: - - dice - model: - amp: true - architecture: unet - base_filters: 32 - batch_norm: false - class_list: - - 0 - - 1 - dimension: 3 - final_layer: sigmoid - ignore_label_validation: null - norm_type: instance - num_channels: 1 - nested_training: - testing: -5 - validation: -5 - num_epochs: 1 - optimizer: - type: adam - output_dir: . 
- parallel_compute_command: '' - patch_sampler: uniform - patch_size: - - 32 - - 32 - - 32 - patience: 1 - pin_memory_dataloader: false - print_rgb_label_warning: true - q_max_length: 1 - q_num_workers: 0 - q_samples_per_volume: 1 - q_verbose: false - save_output: false - save_training: false - scaling_factor: 1 - scheduler: - type: triangle - track_memory_usage: false - verbose: false - version: - maximum: 0.0.14 - minimum: 0.0.13 - weighted_loss: true - train_csv: seg_test_train.csv - val_csv: seg_test_val.csv - template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner - tasks: - aggregated_model_validation: - function: validate - kwargs: - apply: global - metrics: - - valid_loss - - valid_dice - locally_tuned_model_validation: - function: validate - kwargs: - apply: local - metrics: - - valid_loss - - valid_dice - settings: {} - train: - function: train - kwargs: - epochs: 1 - metrics: - - loss - - train_dice diff --git a/examples/fl/project/Dockerfile-CPU b/examples/fl/project/Dockerfile-CPU deleted file mode 100644 index 270d72d8c..000000000 --- a/examples/fl/project/Dockerfile-CPU +++ /dev/null @@ -1,31 +0,0 @@ -FROM local/openfl:local - -ENV GANDLF_VERSION 60c9d28aa5e1b951e75ed5646ac20d5790fe4317 -ENV FL_WORKSPACE /mlcube_project/fl_workspace -ENV LANG C.UTF-8 - -# install software requirements needed by GaNDLF -RUN apt-get update && apt-get upgrade -y && apt-get install -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ - -# install GaNDLF (cpu) -RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu && \ - pip install --no-cache-dir openvino-dev==2023.0.1 && \ - git clone https://github.com/mlcommons/GaNDLF.git && \ - cd GaNDLF && git checkout $GANDLF_VERSION && \ - pip install --no-cache-dir -e . 
- - -# install workspace requirements -COPY ./fl_workspace/requirements.txt $FL_WORKSPACE/requirements.txt -RUN pip install --no-cache-dir -r $FL_WORKSPACE/requirements.txt - -# START hotfix: patch gandlf runner -RUN rm /openfl/openfl/federated/task/runner_gandlf.py -COPY ./hotfix.py /openfl/openfl/federated/task/runner_gandlf.py -RUN pip install --no-cache-dir -e /openfl/ -# END hotfix - -# Copy mlcube workspace -COPY . /mlcube_project - -ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/project/Dockerfile-GPU b/examples/fl/project/Dockerfile-GPU deleted file mode 100644 index de67e0373..000000000 --- a/examples/fl/project/Dockerfile-GPU +++ /dev/null @@ -1,28 +0,0 @@ -FROM local/openfl:local - -ENV GANDLF_VERSION 60c9d28aa5e1b951e75ed5646ac20d5790fe4317 -ENV FL_WORKSPACE /mlcube_project/fl_workspace -ENV LANG C.UTF-8 -ENV CUDA_VISIBLE_DEVICES="0" -# TODO: combine docker images (cpu and gpu) -# TODO: make necessary changes since now user is not root - -# install software requirements needed by GaNDLF -RUN apt-get update && apt-get upgrade -y && apt-get install -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ - -# install GaNDLF (cpu) -RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 && \ - pip install openvino-dev==2023.0.1 && \ - git clone https://github.com/mlcommons/GaNDLF.git && \ - cd GaNDLF && git checkout $GANDLF_VERSION && \ - pip install -e . - - -# install workspace requirements -COPY ./fl_workspace/requirements.txt $FL_WORKSPACE/requirements.txt -RUN pip install --no-cache-dir -r $FL_WORKSPACE/requirements.txt - -# Copy mlcube workspace -COPY . 
/mlcube_project - -ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/project/README.md b/examples/fl/project/README.md deleted file mode 100644 index ae4653520..000000000 --- a/examples/fl/project/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# How to configure container build - -- List your pip requirements in `openfl_workspace/requirements.txt` -- Modify container base image and/or how GaNDLF is installed in `dockerfile` to have GPU support. -- Modify the GaNDLF hash (or simply copy your customized GaNDLF repo code) in `dockerfile` to use a custom GaNDLF version. - -Note: the plan to be attached to the container should be GaNDLF+OpenFL plan (I guess). - -# How to build - -- Build the openfl base image: - -```bash -git clone https://github.com/securefederatedai/openfl.git -git checkout 11db12785c1a6a2d3c75656b38108443f88919e8 -cd openfl -docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . -cd .. -rm -rf openfl -``` - -- Build the MLCube - -```bash -cd ../mlcube -mlcube configure -Pdocker.build_strategy=always -``` - -# Expected assets to be attached - -(outdated) - -- cert folders: certificates of the aggregator/collaborator and the CA's public key -- collaborator list, FL plan, and the init weights for the aggregator -- training data for the collaborator - -# NOTE - -for local experiments, internal IP address or localhost will not work. Use internal fqdn. - -# For later - -- To use a plan that doesn't depend on GaNDLF, maybe `openfl_workspace/src` should be prepopulated with the necessary code. 
diff --git a/examples/fl/project/collaborator.py b/examples/fl/project/collaborator.py deleted file mode 100644 index d49a89483..000000000 --- a/examples/fl/project/collaborator.py +++ /dev/null @@ -1,29 +0,0 @@ -from utils import ( - get_collaborator_cn, - prepare_node_cert, - prepare_ca_cert, - prepare_plan, - prepare_data, - WORKSPACE, -) -import os -from subprocess import check_call - - -def start_collaborator( - data_path, - labels_path, - parameters_file, - node_cert_folder, - ca_cert_folder, - network_config, - output_logs, # TODO: Is it needed? -): - prepare_plan(parameters_file, network_config) - cn = get_collaborator_cn() - prepare_node_cert(node_cert_folder, "client", f"col_{cn}") - prepare_ca_cert(ca_cert_folder) - prepare_data(data_path, labels_path, cn) - - # set log files - check_call(["fx", "collaborator", "start", "-n", cn], cwd=WORKSPACE) diff --git a/examples/fl/project/fl_workspace/.workspace b/examples/fl/project/fl_workspace/.workspace deleted file mode 100644 index 3c2c5d08b..000000000 --- a/examples/fl/project/fl_workspace/.workspace +++ /dev/null @@ -1,2 +0,0 @@ -current_plan_name: default - diff --git a/examples/fl/project/fl_workspace/plan/defaults b/examples/fl/project/fl_workspace/plan/defaults deleted file mode 100644 index fb82f9c5b..000000000 --- a/examples/fl/project/fl_workspace/plan/defaults +++ /dev/null @@ -1,2 +0,0 @@ -../../workspace/plan/defaults - diff --git a/examples/fl/project/hotfix.py b/examples/fl/project/hotfix.py deleted file mode 100644 index bbcc692a6..000000000 --- a/examples/fl/project/hotfix.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""GaNDLFTaskRunner module.""" - -from copy import deepcopy - -import numpy as np -import os -import torch as pt -from typing import Union -import yaml - -from openfl.utilities.split import split_tensor_dict_for_holdouts -from openfl.utilities import TensorKey - -from .runner import TaskRunner - -from 
GANDLF.compute.generic import create_pytorch_objects -from GANDLF.compute.training_loop import train_network -from GANDLF.compute.forward_pass import validate_network - - -class GaNDLFTaskRunner(TaskRunner): - """GaNDLF Model class for Federated Learning.""" - - def __init__( - self, - gandlf_config: Union[str, dict] = None, - device: str = None, - **kwargs - ): - """Initialize. - Args: - device (string): Compute device (default="cpu") - **kwargs: Additional parameters to pass to the functions - """ - super().__init__(**kwargs) - - # allow pass-through of a gandlf config as a file or a dict - - train_csv = self.data_loader.train_csv - val_csv = self.data_loader.val_csv - - if isinstance(gandlf_config, str) and os.path.exists(gandlf_config): - gandlf_config = yaml.safe_load(open(gandlf_config, "r")) - - ( - model, - optimizer, - train_loader, - val_loader, - scheduler, - params, - ) = create_pytorch_objects( - gandlf_config, train_csv=train_csv, val_csv=val_csv, device=device - ) - self.model = model - self.optimizer = optimizer - self.scheduler = scheduler - self.params = params - self.device = device - - # pass the actual dataloaders to the wrapper loader - self.data_loader.set_dataloaders(train_loader, val_loader) - - self.training_round_completed = False - - self.required_tensorkeys_for_function = {} - - # FIXME: why isn't this initial call in runner_pt? - self.initialize_tensorkeys_for_functions(with_opt_vars=False) - - # overwrite attribute to account for one optimizer param (in every - # child model that does not overwrite get and set tensordict) that is - # not a numpy array - self.tensor_dict_split_fn_kwargs.update({ - 'holdout_tensor_names': ['__opt_state_needed'] - }) - - def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. 
- Returns: - None - """ - - if self.opt_treatment == 'RESET': - self.reset_opt_vars() - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - elif (self.training_round_completed - and self.opt_treatment == 'CONTINUE_GLOBAL' and not validation): - self.set_tensor_dict(input_tensor_dict, with_opt_vars=True) - else: - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - - def validate(self, col_name, round_num, input_tensor_dict, - use_tqdm=False, **kwargs): - """Validate. - Run validation of the model on the local data. - Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) - kwargs: Key word arguments passed to GaNDLF main_run - Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB - """ - self.rebuild_model(round_num, input_tensor_dict, validation=True) - self.model.eval() - - epoch_valid_loss, epoch_valid_metric = validate_network(self.model, - self.data_loader.val_dataloader, - self.scheduler, - self.params, - round_num, - mode="validation") - - self.logger.info(epoch_valid_loss) - self.logger.info(epoch_valid_metric) - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - - output_tensor_dict = {} - valid_loss_tensor_key = TensorKey('valid_loss', origin, round_num, True, tags) - output_tensor_dict[valid_loss_tensor_key] = np.array(epoch_valid_loss) - for k, v in epoch_valid_metric.items(): - if isinstance(v, str) and "_" in v: - continue - tensor_key = TensorKey(f'valid_{k}', origin, round_num, True, tags) - output_tensor_dict[tensor_key] = np.array(v) - - # Empty list represents metrics that should only be stored locally - return output_tensor_dict, {} - - def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, 
epochs=1, **kwargs): - """Train batches. - Train the model on the requested number of batches. - Args: - col_name : Name of the collaborator - round_num : What round is it - input_tensor_dict : Required input tensors (for model) - use_tqdm (bool) : Use tqdm to print a progress bar (Default=True) - epochs : The number of epochs to train - crossfold_test : Whether or not to use cross fold trainval/test - to evaluate the quality of the model under fine tuning - (this uses a separate prameter to pass in the data and - config used) - crossfold_test_data_csv : Data csv used to define data used in crossfold test. - This csv does not itself define the folds, just - defines the total data to be used. - crossfold_val_n : number of folds to use for the train,val level - of the nested crossfold. - corssfold_test_n : number of folds to use for the trainval,test level - of the nested crossfold. - kwargs : Key word arguments passed to GaNDLF main_run - Returns: - global_output_dict : Tensors to send back to the aggregator - local_output_dict : Tensors to maintain in the local TensorDB - """ - self.rebuild_model(round_num, input_tensor_dict) - # set to "training" mode - self.model.train() - for epoch in range(epochs): - self.logger.info(f'Run {epoch} epoch of {round_num} round') - # FIXME: do we want to capture these in an array - # rather than simply taking the last value? - epoch_train_loss, epoch_train_metric = train_network(self.model, - self.data_loader.train_dataloader, - self.optimizer, - self.params) - - # output model tensors (Doesn't include TensorKey) - tensor_dict = self.get_tensor_dict(with_opt_vars=True) - - metric_dict = {'loss': epoch_train_loss} - for k, v in epoch_train_metric.items(): - if isinstance(v, str) and "_" in v: - continue - metric_dict[f'train_{k}'] = v - - # Return global_tensor_dict, local_tensor_dict - # is this even pt-specific really? 
- global_tensor_dict, local_tensor_dict = create_tensorkey_dicts( - tensor_dict, - metric_dict, - col_name, - round_num, - self.logger, - self.tensor_dict_split_fn_kwargs, - ) - - # Update the required tensors if they need to be pulled from the - # aggregator - # TODO this logic can break if different collaborators have different - # roles between rounds. - # For example, if a collaborator only performs validation in the first - # round but training in the second, it has no way of knowing the - # optimizer state tensor names to request from the aggregator because - # these are only created after training occurs. A work around could - # involve doing a single epoch of training on random data to get the - # optimizer names, and then throwing away the model. - if self.opt_treatment == 'CONTINUE_GLOBAL': - self.initialize_tensorkeys_for_functions(with_opt_vars=True) - - # This will signal that the optimizer values are now present, - # and can be loaded when the model is rebuilt - self.training_round_completed = True - - # Return global_tensor_dict, local_tensor_dict - return global_tensor_dict, local_tensor_dict - - def get_tensor_dict(self, with_opt_vars=False): - """Return the tensor dictionary. - Args: - with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) - Returns: - dict: Tensor dictionary {**dict, **optimizer_dict} - """ - # Gets information regarding tensor model layers and optimizer state. - # FIXME: self.parameters() instead? Unclear if load_state_dict() or - # simple assignment is better - # for now, state dict gives us names which is good - # FIXME: do both and sanity check each time? - - state = to_cpu_numpy(self.model.state_dict()) - - if with_opt_vars: - opt_state = _get_optimizer_state(self.optimizer) - state = {**state, **opt_state} - - return state - - def _get_weights_names(self, with_opt_vars=False): - # Gets information regarding tensor model layers and optimizer state. 
- # FIXME: self.parameters() instead? Unclear if load_state_dict() or - # simple assignment is better - # for now, state dict gives us names which is good - # FIXME: do both and sanity check each time? - - state = self.model.state_dict().keys() - - if with_opt_vars: - opt_state = _get_optimizer_state(self.model.optimizer) - state += opt_state.keys() - - return state - - def set_tensor_dict(self, tensor_dict, with_opt_vars=False): - """Set the tensor dictionary. - Args: - tensor_dict: The tensor dictionary - with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) - """ - set_pt_model_from_tensor_dict(self.model, tensor_dict, self.device, with_opt_vars) - - def get_optimizer(self): - """Get the optimizer of this instance.""" - return self.optimizer - - def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called \ - as part of a task. By default, this is just all of the layers and \ - optimizer of the model. - Args: - func_name - Returns: - list : [TensorKey] - """ - if func_name == 'validate': - local_model = 'apply=' + str(kwargs['apply']) - return self.required_tensorkeys_for_function[func_name][local_model] - else: - return self.required_tensorkeys_for_function[func_name] - - def initialize_tensorkeys_for_functions(self, with_opt_vars=False): - """Set the required tensors for all publicly accessible task methods. - By default, this is just all of the layers and optimizer of the model. - Custom tensors should be added to this function. - Args: - None - Returns: - None - """ - # TODO there should be a way to programmatically iterate through - # all of the methods in the class and declare the tensors. 
- # For now this is done manually - - output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) - if not with_opt_vars: - global_model_dict_val = global_model_dict - local_model_dict_val = local_model_dict - else: - output_model_dict = self.get_tensor_dict(with_opt_vars=False) - global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) - - self.required_tensorkeys_for_function['train'] = [ - TensorKey( - tensor_name, 'GLOBAL', 0, False, ('model',) - ) for tensor_name in global_model_dict - ] - self.required_tensorkeys_for_function['train'] += [ - TensorKey( - tensor_name, 'LOCAL', 0, False, ('model',) - ) for tensor_name in local_model_dict - ] - - # Validation may be performed on local or aggregated (global) model, - # so there is an extra lookup dimension for kwargs - self.required_tensorkeys_for_function['validate'] = {} - # TODO This is not stateless. The optimizer will not be - self.required_tensorkeys_for_function['validate']['apply=local'] = [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) - for tensor_name in { - **global_model_dict_val, - **local_model_dict_val - }] - self.required_tensorkeys_for_function['validate']['apply=global'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in global_model_dict_val - ] - self.required_tensorkeys_for_function['validate']['apply=global'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in local_model_dict_val - ] - - def load_native(self, filepath, model_state_dict_key='model_state_dict', - optimizer_state_dict_key='optimizer_state_dict', **kwargs): - """ - Load model and optimizer states from a pickled file specified by \ - filepath. model_/optimizer_state_dict args can be specified if needed. 
\ - Uses pt.load(). - Args: - filepath (string) : Path to pickle file created - by pt.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state dict - in picked file. - kwargs : unused - Returns: - None - """ - pickle_dict = pt.load(filepath) - self.model.load_state_dict(pickle_dict[model_state_dict_key]) - self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key]) - - def save_native(self, filepath, model_state_dict_key='model_state_dict', - optimizer_state_dict_key='optimizer_state_dict', **kwargs): - """ - Save model and optimizer states in a picked file specified by the \ - filepath. model_/optimizer_state_dicts are stored in the keys provided. \ - Uses pt.save(). - Args: - filepath (string) : Path to pickle file to be - created by pt.save(). - model_state_dict_key (string) : key for model state dict - in pickled file. - optimizer_state_dict_key (string) : key for optimizer state - dict in picked file. - kwargs : unused - Returns: - None - """ - pickle_dict = { - model_state_dict_key: self.model.state_dict(), - optimizer_state_dict_key: self.optimizer.state_dict() - } - pt.save(pickle_dict, filepath) - - def reset_opt_vars(self): - """ - Reset optimizer variables. 
- Resets the optimizer variables - """ - pass - - -def create_tensorkey_dicts(tensor_dict, - metric_dict, - col_name, - round_num, - logger, - tensor_dict_split_fn_kwargs): - origin = col_name - tags = ('trained',) - output_metric_dict = {} - for k, v in metric_dict.items(): - tk = TensorKey(k, origin, round_num, True, ('metric',)) - output_metric_dict[tk] = np.array(v) - - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - logger, tensor_dict, **tensor_dict_split_fn_kwargs - ) - - # Create global tensorkeys - global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items() - } - # Create tensorkeys that should stay local - local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items() - } - # The train/validate aggregated function of the next round will look - # for the updated model parameters. - # This ensures they will be resolved locally - next_local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num + 1, False, ('model',)): nparray - for tensor_name, nparray in local_model_dict.items()} - - global_tensor_dict = { - **output_metric_dict, - **global_tensorkey_model_dict - } - local_tensor_dict = { - **local_tensorkey_model_dict, - **next_local_tensorkey_model_dict - } - - return global_tensor_dict, local_tensor_dict - - -def set_pt_model_from_tensor_dict(model, tensor_dict, device, with_opt_vars=False): - """Set the tensor dictionary. - Args: - model: the pytorch nn.module object - tensor_dict: The tensor dictionary - device: the device where the tensor values need to be sent - with_opt_vars (bool): Return the tensor dictionary including the - optimizer tensors (Default=False) - """ - # Sets tensors for model layers and optimizer state. - # FIXME: model.parameters() instead? 
Unclear if load_state_dict() or - # simple assignment is better - # for now, state dict gives us names, which is good - # FIXME: do both and sanity check each time? - - new_state = {} - # Grabbing keys from model's state_dict helps to confirm we have - # everything - for k in model.state_dict(): - new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device) - - # set model state - model.load_state_dict(new_state) - - if with_opt_vars: - # see if there is state to restore first - if tensor_dict.pop('__opt_state_needed') == 'true': - _set_optimizer_state(model.get_optimizer(), device, tensor_dict) - - # sanity check that we did not record any state that was not used - assert len(tensor_dict) == 0 - - -def _derive_opt_state_dict(opt_state_dict): - """Separate optimizer tensors from the tensor dictionary. - Flattens the optimizer state dict so as to have key, value pairs with - values as numpy arrays. - The keys have sufficient info to restore opt_state_dict using - expand_derived_opt_state_dict. - Args: - opt_state_dict: The optimizer state dictionary - """ - derived_opt_state_dict = {} - - # Determine if state is needed for this optimizer. - if len(opt_state_dict['state']) == 0: - derived_opt_state_dict['__opt_state_needed'] = 'false' - return derived_opt_state_dict - - derived_opt_state_dict['__opt_state_needed'] = 'true' - - # Using one example state key, we collect keys for the corresponding - # dictionary value. - example_state_key = opt_state_dict['param_groups'][0]['params'][0] - example_state_subkeys = set( - opt_state_dict['state'][example_state_key].keys() - ) - - # We assume that the state collected for all params in all param groups is - # the same. - # We also assume that whether or not the associated values to these state - # subkeys is a tensor depends only on the subkey. - # Using assert statements to break the routine if these assumptions are - # incorrect. 
- for state_key in opt_state_dict['state'].keys(): - assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) - for state_subkey in example_state_subkeys: - assert (isinstance( - opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor) - == isinstance( - opt_state_dict['state'][state_key][state_subkey], - pt.Tensor)) - - state_subkeys = list(opt_state_dict['state'][example_state_key].keys()) - - # Tags will record whether the value associated to the subkey is a - # tensor or not. - state_subkey_tags = [] - for state_subkey in state_subkeys: - if isinstance( - opt_state_dict['state'][example_state_key][state_subkey], - pt.Tensor - ): - state_subkey_tags.append('istensor') - else: - state_subkey_tags.append('') - state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags)) - - # Forming the flattened dict, using a concatenation of group index, - # subindex, tag, and subkey inserted into the flattened dict key - - # needed for reconstruction. - nb_params_per_group = [] - for group_idx, group in enumerate(opt_state_dict['param_groups']): - for idx, param_id in enumerate(group['params']): - for subkey, tag in state_subkeys_and_tags: - if tag == 'istensor': - new_v = opt_state_dict['state'][param_id][ - subkey].cpu().numpy() - else: - new_v = np.array( - [opt_state_dict['state'][param_id][subkey]] - ) - derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v - nb_params_per_group.append(idx + 1) - # group lengths are also helpful for reconstructing - # original opt_state_dict structure - derived_opt_state_dict['__opt_group_lengths'] = np.array( - nb_params_per_group - ) - - return derived_opt_state_dict - - -def expand_derived_opt_state_dict(derived_opt_state_dict, device): - """Expand the optimizer state dictionary. - Takes a derived opt_state_dict and creates an opt_state_dict suitable as - input for load_state_dict for restoring optimizer state. 
- Reconstructing state_subkeys_and_tags using the example key - prefix, "__opt_state_0_0_", certain to be present. - Args: - derived_opt_state_dict: Optimizer state dictionary - Returns: - dict: Optimizer state dictionary - """ - state_subkeys_and_tags = [] - for key in derived_opt_state_dict: - if key.startswith('__opt_state_0_0_'): - stripped_key = key[16:] - if stripped_key.startswith('istensor_'): - this_tag = 'istensor' - subkey = stripped_key[9:] - else: - this_tag = '' - subkey = stripped_key[1:] - state_subkeys_and_tags.append((subkey, this_tag)) - - opt_state_dict = {'param_groups': [], 'state': {}} - nb_params_per_group = list( - derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) - ) - - # Construct the expanded dict. - for group_idx, nb_params in enumerate(nb_params_per_group): - these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)] - opt_state_dict['param_groups'].append({'params': these_group_ids}) - for this_id in these_group_ids: - opt_state_dict['state'][this_id] = {} - for subkey, tag in state_subkeys_and_tags: - flat_key = f'__opt_state_{this_id}_{tag}_{subkey}' - if tag == 'istensor': - new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key)) - else: - # Here (for currrently supported optimizers) the subkey - # should be 'step' and the length of array should be one. - assert subkey == 'step' - assert len(derived_opt_state_dict[flat_key]) == 1 - new_v = int(derived_opt_state_dict.pop(flat_key)) - opt_state_dict['state'][this_id][subkey] = new_v - - # sanity check that we did not miss any optimizer state - assert len(derived_opt_state_dict) == 0 - - return opt_state_dict - - -def _get_optimizer_state(optimizer): - """Return the optimizer state. 
- Args: - optimizer - """ - opt_state_dict = deepcopy(optimizer.state_dict()) - - # Optimizer state might not have some parts representing frozen parameters - # So we do not synchronize them - param_keys_with_state = set(opt_state_dict['state'].keys()) - for group in opt_state_dict['param_groups']: - local_param_set = set(group['params']) - params_to_sync = local_param_set & param_keys_with_state - group['params'] = sorted(params_to_sync) - - derived_opt_state_dict = _derive_opt_state_dict(opt_state_dict) - - return derived_opt_state_dict - - -def _set_optimizer_state(optimizer, device, derived_opt_state_dict): - """Set the optimizer state. - Args: - optimizer: - device: - derived_opt_state_dict: - """ - temp_state_dict = expand_derived_opt_state_dict( - derived_opt_state_dict, device) - - # FIXME: Figure out whether or not this breaks learning rate - # scheduling and the like. - # Setting default values. - # All optimizer.defaults are considered as not changing over course of - # training. - for group in temp_state_dict['param_groups']: - for k, v in optimizer.defaults.items(): - group[k] = v - - optimizer.load_state_dict(temp_state_dict) - - -def to_cpu_numpy(state): - """Send data to CPU as Numpy array. - Args: - state - """ - # deep copy so as to decouple from active model - state = deepcopy(state) - - for k, v in state.items(): - # When restoring, we currently assume all values are tensors. 
- if not pt.is_tensor(v): - raise ValueError('We do not currently support non-tensors ' - 'coming from model.state_dict()') - # get as a numpy array, making sure is on cpu - state[k] = v.cpu().numpy() - return state diff --git a/examples/fl/project/mlcube.py b/examples/fl/project/mlcube.py deleted file mode 100644 index 6ee9e2de3..000000000 --- a/examples/fl/project/mlcube.py +++ /dev/null @@ -1,34 +0,0 @@ -"""MLCube handler file""" -import argparse -from collaborator import start_collaborator -from aggregator import start_aggregator - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - - train = subparsers.add_parser("train") - train.add_argument("--data_path", metavar="", type=str, required=True) - train.add_argument("--labels_path", metavar="", type=str, required=True) - train.add_argument("--parameters_file", metavar="", type=str, required=True) - train.add_argument("--node_cert_folder", metavar="", type=str, required=True) - train.add_argument("--ca_cert_folder", metavar="", type=str, required=True) - train.add_argument("--network_config", metavar="", type=str, required=True) - train.add_argument("--output_logs", metavar="", type=str, required=True) - - agg = subparsers.add_parser("start_aggregator") - agg.add_argument("--input_weights", metavar="", type=str, required=True) - agg.add_argument("--parameters_file", metavar="", type=str, required=True) - agg.add_argument("--node_cert_folder", metavar="", type=str, required=True) - agg.add_argument("--ca_cert_folder", metavar="", type=str, required=True) - agg.add_argument("--output_logs", metavar="", type=str, required=True) - agg.add_argument("--output_weights", metavar="", type=str, required=True) - agg.add_argument("--network_config", metavar="", type=str, required=True) - agg.add_argument("--collaborators", metavar="", type=str, required=True) - - args = parser.parse_args() - if hasattr(args, "data_path"): - start_collaborator(**vars(args)) - else: - 
start_aggregator(**vars(args)) From 2ff534d6bd8b98a11c67b6e0e95c884098f273c1 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 12 Mar 2024 03:13:14 +0100 Subject: [PATCH 018/242] update cli training tests --- cli/cli_tests_training.sh | 45 +++++++++++++++++---------------------- cli/tests_setup.sh | 33 ++++++++++++++-------------- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 842137b33..76eaf0ba3 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Creating test profiles for each user" @@ -133,12 +132,10 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -HOSTNAME=$(hostname -A | cut -d " " -f 1) -medperf aggregator submit -n aggreg -a $HOSTNAME -p 50273 +HOSTNAME_=$(hostname -A | cut -d " " -f 1) +medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 checkFailed "aggregator submission step failed" -medperf aggregator ls -medperf aggregator ls | grep aggreg | tr -s ' ' | cut -d ' ' -f 1 -AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | cut -d ' ' -f 1) +AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -169,7 +166,7 @@ echo "Running data1 preparation step" echo "=====================================" medperf dataset create -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" checkFailed "Data1 preparation step failed" -DSET_1_GENUID=$(medperf dataset ls | grep col1 | tr -s ' ' | cut -d ' ' -f 1) +DSET_1_GENUID=$(medperf dataset ls | 
grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -180,7 +177,7 @@ echo "Running data1 submission step" echo "=====================================" medperf dataset submit -d $DSET_1_GENUID -y checkFailed "Data1 submission step failed" -DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | cut -d ' ' -f 1) +DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -189,7 +186,7 @@ echo "\n" echo "=====================================" echo "Running data1 association step" echo "=====================================" -medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y +medperf training associate_dataset -d $DSET_1_UID -t $TRAINING_UID -y checkFailed "Data1 association step failed" ########################################################## @@ -211,7 +208,7 @@ echo "Running data2 preparation step" echo "=====================================" medperf dataset create -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" checkFailed "Data2 preparation step failed" -DSET_2_GENUID=$(medperf dataset ls | grep col2 | tr -s ' ' | cut -d ' ' -f 1) +DSET_2_GENUID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -222,7 +219,7 @@ echo "Running data2 submission step" echo "=====================================" medperf dataset submit -d $DSET_2_GENUID -y checkFailed "Data2 submission step failed" -DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | cut -d ' ' -f 1) +DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -231,7 +228,7 @@ echo "\n" echo "=====================================" 
echo "Running data2 association step" echo "=====================================" -medperf dataset associate -d $DSET_2_UID -t $TRAINING_UID -y +medperf training associate_dataset -d $DSET_2_UID -t $TRAINING_UID -y checkFailed "Data2 association step failed" ########################################################## @@ -251,7 +248,7 @@ echo "\n" echo "=====================================" echo "Approve aggregator association" echo "=====================================" -medperf association approve -t $TRAINING_UID -a $AGG_UID +medperf training approve_association -t $TRAINING_UID -a $AGG_UID checkFailed "agg association approval failed" ########################################################## @@ -261,7 +258,7 @@ echo "\n" echo "=====================================" echo "Approve data1 association" echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_1_UID +medperf training approve_association -t $TRAINING_UID -d $DSET_1_UID checkFailed "data1 association approval failed" ########################################################## @@ -271,7 +268,7 @@ echo "\n" echo "=====================================" echo "Approve data2 association" echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_2_UID +medperf training approve_association -t $TRAINING_UID -d $DSET_2_UID checkFailed "data2 association approval failed" ########################################################## @@ -302,15 +299,14 @@ echo "=====================================" echo "Starting aggregator" echo "=====================================" RUNCOMMAND="medperf aggregator start -a $AGG_UID -t $TRAINING_UID" -nohup $RUNCOMMAND < /dev/null &>agg.log & +nohup $RUNCOMMAND agg.log & # sleep so that the mlcube is run before we change profiles sleep 7 -AGG_PID=$(ps -ef | grep $RUNCOMMAND | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +AGG_PID=$(ps -ef | grep "$RUNCOMMAND" | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) # 
Check if the command is still running. -if ! kill -0 "$AGG_PID" &> /dev/null; -then +if ! kill -0 "$AGG_PID" &>/dev/null; then checkFailed "agg doesn't seem to be running" 1 fi ########################################################## @@ -332,15 +328,14 @@ echo "=====================================" echo "Starting training with data1" echo "=====================================" RUNCOMMAND="medperf training run -d $DSET_1_UID -t $TRAINING_UID" -nohup $RUNCOMMAND < /dev/null &>col1.log & +nohup $RUNCOMMAND col1.log & # sleep so that the mlcube is run before we change profiles sleep 7 -DATA1_PID=$(ps -ef | grep $RUNCOMMAND | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) +DATA1_PID=$(ps -ef | grep "$RUNCOMMAND" | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) # Check if the command is still running. -if ! kill -0 "$DATA1_PID" &> /dev/null; -then +if ! kill -0 "$DATA1_PID" &>/dev/null; then checkFailed "data1 training doesn't seem to be running" 1 fi ########################################################## @@ -372,10 +367,10 @@ echo "=====================================" echo "Waiting for other prcocesses to exit successfully" echo "=====================================" # NOTE: on systems with small process ID table or very short-lived processes, -# there is a probability that PIDs are reused and hence the +# there is a probability that PIDs are reused and hence the # code below may be inaccurate. Perhaps grep processes according to command # string is the most efficient way to reduce that probability further. -# Followup NOTE: not sure, but the "wait" command may fail if it is waiting for +# Followup NOTE: not sure, but the "wait" command may fail if it is waiting for # a process that is not a child of the current shell wait $DATA1_PID checkFailed "data1 training didn't exit successfully" diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index db6e23b58..af4a30c71 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -1,13 +1,12 @@ #! 
/bin/bash -while getopts s:d:c:ft: flag -do - case "${flag}" in - s) SERVER_URL=${OPTARG};; - d) DIRECTORY=${OPTARG};; - c) CLEANUP="true";; - f) FRESH="true";; - t) TIMEOUT=${OPTARG};; - esac +while getopts s:d:c:ft: flag; do + case "${flag}" in + s) SERVER_URL=${OPTARG} ;; + d) DIRECTORY=${OPTARG} ;; + c) CLEANUP="true" ;; + f) FRESH="true" ;; + t) TIMEOUT=${OPTARG} ;; + esac done SERVER_URL="${SERVER_URL:-https://localhost:8000}" @@ -26,7 +25,7 @@ echo "Server URL: $SERVER_URL" echo "Storage location: $MEDPERF_SUBSTORAGE" # frequently used -clean(){ +clean() { echo "=====================================" echo "Cleaning up medperf tmp files" echo "=====================================" @@ -38,8 +37,11 @@ clean(){ medperf profile delete testbenchmark medperf profile delete testmodel medperf profile delete testdata + medperf profile delete testagg + medperf profile delete testdata1 + medperf profile delete testdata2 } -checkFailed(){ +checkFailed() { EXITSTATUS="$?" if [ -n "$2" ]; then EXITSTATUS="1" @@ -58,7 +60,6 @@ checkFailed(){ fi } - if ${FRESH}; then clean fi @@ -76,7 +77,7 @@ DEMO_URL="${ASSETS_URL}/assets/datasets/demo_dset1.tar.gz" # prep cubes PREP_MLCUBE="$ASSETS_URL/prep-sep/mlcube/mlcube.yaml" PREP_PARAMS="$ASSETS_URL/prep-sep/mlcube/workspace/parameters.yaml" -PREP_TRAINING_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube_prep.yaml" +PREP_TRAINING_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube.yaml" # model cubes FAILING_MODEL_MLCUBE="$ASSETS_URL/model-bug/mlcube/mlcube.yaml" # doesn't fail with association @@ -99,8 +100,8 @@ METRIC_MLCUBE="$ASSETS_URL/metrics/mlcube/mlcube.yaml" METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" # FL cubes -TRAIN_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube-cpu.yaml?v=2" -TRAIN_PARAMS="https://storage.googleapis.com/medperf-storage/testfl/parameters-miccai.yaml" 
+TRAIN_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube1.0.0.yaml" +TRAIN_PARAMS="https://storage.googleapis.com/medperf-storage/testfl/parameters1.0.0.yaml" TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" # test users credentials @@ -109,4 +110,4 @@ DATAOWNER="testdo@example.com" BENCHMARKOWNER="testbo@example.com" ADMIN="testadmin@example.com" DATAOWNER2="testdo2@example.com" -AGGOWNER="testao@example.com" \ No newline at end of file +AGGOWNER="testao@example.com" From ee8d19548845df6c2d7cc9a2bc1e5ae49ad48863 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 12 Mar 2024 03:13:36 +0100 Subject: [PATCH 019/242] fix bugs --- cli/medperf/commands/training/run.py | 2 +- cli/medperf/config.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py index d0648726c..f47904d91 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -92,7 +92,7 @@ def run_experiment(self): env_dict = {"COLLABORATOR_CN": dataset_cn} # just for now create some output folders (TODO) - out_logs = os.path.join(self.training_exp.path, "data_logs") + out_logs = os.path.join(self.training_exp.path, "data_logs", dataset_cn) os.makedirs(out_logs, exist_ok=True) params = { diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 98d91b36f..b37b016b4 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -130,8 +130,6 @@ "logs_folder", "tmp_folder", "demo_datasets_folder", - "training_folder", - "aggregators_folder", ] server_folders = [ "benchmarks_folder", @@ -141,6 +139,8 @@ "results_folder", "predictions_folder", "tests_folder", + "training_folder", + "aggregators_folder", ] # MedPerf filenames conventions From 9f2935acd52e5f3e5d1e2d8f3080826189192fa5 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 21 Mar 2024 01:30:54 +0100 Subject: [PATCH 020/242] TMP TO BE REVERTED --- 
cli/medperf/cli.py | 5 +++-- cli/medperf/commands/mlcube/mlcube.py | 25 +++++++++++++++++++++++++ cli/medperf/commands/mlcube/run.py | 11 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 cli/medperf/commands/mlcube/run.py diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 3d4d1ff52..a6ab8caba 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -19,7 +19,8 @@ import medperf.commands.training.training as training import medperf.commands.aggregator.aggregator as aggregator import medperf.commands.storage as storage -from medperf.utils import check_for_updates + +# from medperf.utils import check_for_updates app = typer.Typer() app.add_typer(mlcube.app, name="mlcube", help="Manage mlcubes") @@ -104,6 +105,6 @@ def main( logging.info(f"Running MedPerf v{__version__} on {loglevel} logging level") logging.info(f"Executed command: {' '.join(sys.argv[1:])}") - check_for_updates() + # check_for_updates() config.ui.print(f"MedPerf {__version__}") diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index e5b0253ee..7605db45b 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -9,10 +9,35 @@ from medperf.commands.mlcube.create import CreateCube from medperf.commands.mlcube.submit import SubmitCube from medperf.commands.mlcube.associate import AssociateCube +from medperf.commands.mlcube.run import run_mlcube app = typer.Typer() +@app.command("run") +@clean_except +def run( + mlcube_path: str = typer.Option( + ..., "--mlcube", "-m", help="path to mlcube folder" + ), + task: str = typer.Option(..., "--task", "-t", help="mlcube task to run"), + out_logs: str = typer.Option( + None, "--output-logs", "-o", help="where to store stdout" + ), + port: str = typer.Option(None, "--port", "-P", help="port to expose"), + env: str = typer.Option( + "", "--env", "-e", help="comma separated list of key=value pairs" + ), + params: str = typer.Option( + "", 
"--params", "-p", help="comma separated list of key=value pairs" + ), +): + """List mlcubes stored locally and remotely from the user""" + params = dict([p.split("=") for p in params.strip().strip(",").split(",") if p]) + env = dict([p.split("=") for p in env.strip().strip(",").split(",") if p]) + run_mlcube(mlcube_path, task, out_logs, params, port, env) + + @app.command("ls") @clean_except def list( diff --git a/cli/medperf/commands/mlcube/run.py b/cli/medperf/commands/mlcube/run.py new file mode 100644 index 000000000..b239b4672 --- /dev/null +++ b/cli/medperf/commands/mlcube/run.py @@ -0,0 +1,11 @@ +from medperf.tests.mocks.cube import TestCube +import os +from medperf import config + + +def run_mlcube(mlcube_path, task, out_logs, params, port, env): + c = TestCube() + c.path = mlcube_path + c.cube_path = os.path.join(c.path, config.cube_filename) + c.params_path = os.path.join(c.path, config.params_filename) + c.run(task, out_logs, port=port, env_dict=env, **params) From 294fd845f6fd63f642fc557f259b7aa365278673 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 25 Mar 2024 16:40:38 +0100 Subject: [PATCH 021/242] multiple datasets per collaborator cert --- examples/fl/fl/README.md | 2 +- examples/fl/fl/clean.sh | 1 + examples/fl/fl/project/utils.py | 10 +++++++++- examples/fl/fl/setup_clean.sh | 1 + examples/fl/fl/setup_test.sh | 25 ++++++++++++++++++------- examples/fl/fl/sync.sh | 2 ++ examples/fl/fl/test.sh | 5 +++-- 7 files changed, 35 insertions(+), 11 deletions(-) diff --git a/examples/fl/fl/README.md b/examples/fl/fl/README.md index d4228439d..918f483e3 100644 --- a/examples/fl/fl/README.md +++ b/examples/fl/fl/README.md @@ -1,6 +1,6 @@ # How to run tests - Run `setup_test.sh` just once to create certs and download required data. -- Run `test.sh` to start the aggregator and two collaborators. +- Run `test.sh` to start the aggregator and three collaborators. - Run `clean.sh` to be able to rerun `test.sh` freshly. 
- Run `setup_clean.sh` to clear what has been generated in step 1. diff --git a/examples/fl/fl/clean.sh b/examples/fl/fl/clean.sh index f7806d151..cc7bdc725 100644 --- a/examples/fl/fl/clean.sh +++ b/examples/fl/fl/clean.sh @@ -2,3 +2,4 @@ rm -rf mlcube_agg/workspace/final_weights rm -rf mlcube_agg/workspace/logs rm -rf mlcube_col1/workspace/logs rm -rf mlcube_col2/workspace/logs +rm -rf mlcube_col3/workspace/logs diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index 9cd42fa25..a7d2a0851 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -56,13 +56,21 @@ def prepare_plan(parameters_file, network_config, fl_workspace): def prepare_cols_list(collaborators_file, fl_workspace): with open(collaborators_file) as f: cols = f.read().strip().split("\n") + cols = [col.strip().split(",") for col in cols] + cols_dict = {} + for col in cols: + if len(col) == 1: + cols_dict[col[0]] = col[0] + else: + assert len(col) == 2 + cols_dict[col[0]] = col[1] target_plan_folder = os.path.join(fl_workspace, "plan") # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "cols.yaml") with open(target_plan_file, "w") as f: - yaml.dump({"collaborators": cols}, f) + yaml.dump({"collaborators": cols_dict}, f) def prepare_init_weights(input_weights, fl_workspace): diff --git a/examples/fl/fl/setup_clean.sh b/examples/fl/fl/setup_clean.sh index 18ca5536c..9f9242024 100644 --- a/examples/fl/fl/setup_clean.sh +++ b/examples/fl/fl/setup_clean.sh @@ -1,4 +1,5 @@ rm -rf ./mlcube_agg rm -rf ./mlcube_col1 rm -rf ./mlcube_col2 +rm -rf ./mlcube_col3 rm -rf ./ca diff --git a/examples/fl/fl/setup_test.sh b/examples/fl/fl/setup_test.sh index 0285787e3..f728883b3 100644 --- a/examples/fl/fl/setup_test.sh +++ b/examples/fl/fl/setup_test.sh @@ -1,10 +1,12 @@ cp -r ./mlcube ./mlcube_agg cp -r ./mlcube ./mlcube_col1 cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 mkdir 
./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert mkdir ./ca HOSTNAME_=$(hostname -A | cut -d " " -f 1) @@ -15,8 +17,8 @@ openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root. -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" # col1 -sed -i '/^commonName = /c\commonName = col1' csr.conf -sed -i '/^DNS\.1 = /c\DNS.1 = col1' csr.conf +sed -i '/^commonName = /c\commonName = col1@example.com' csr.conf +sed -i '/^DNS\.1 = /c\DNS.1 = col1@example.com' csr.conf cd mlcube_col1/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client @@ -27,9 +29,13 @@ cp ../../../ca/root.crt ../ca_cert/ cd ../../../ # col2 -sed -i '/^commonName = /c\commonName = col2' csr.conf -sed -i '/^DNS\.1 = /c\DNS.1 = col2' csr.conf -cd mlcube_col2/workspace/node_cert +cp mlcube_col1/workspace/node_cert/* mlcube_col2/workspace/node_cert +cp mlcube_col1/workspace/ca_cert/* mlcube_col2/workspace/ca_cert + +# col3 +sed -i '/^commonName = /c\commonName = col3@example.com' csr.conf +sed -i '/^DNS\.1 = /c\DNS.1 = col3@example.com' csr.conf +cd mlcube_col3/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ @@ -58,10 +64,12 @@ echo "port: 50273" >>mlcube_col1/workspace/network.yaml cp mlcube_col1/workspace/network.yaml mlcube_col2/workspace/network.yaml cp mlcube_col1/workspace/network.yaml mlcube_agg/workspace/network.yaml +cp mlcube_col1/workspace/network.yaml 
mlcube_col3/workspace/network.yaml # cols file -echo "col1" >>mlcube_agg/workspace/cols.yaml -echo "col2" >>mlcube_agg/workspace/cols.yaml +echo "abc,col1@example.com" >>mlcube_agg/workspace/cols.yaml +echo "defg,col1@example.com" >>mlcube_agg/workspace/cols.yaml +echo "hij,col3@example.com" >>mlcube_agg/workspace/cols.yaml # data download cd mlcube_col1/workspace/ @@ -76,6 +84,9 @@ tar -xf col2_prepared.tar.gz rm col2_prepared.tar.gz cd ../.. +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + # weights download cd mlcube_agg/workspace/ mkdir additional_files diff --git a/examples/fl/fl/sync.sh b/examples/fl/fl/sync.sh index 1a36e5ab5..d454a9c2f 100644 --- a/examples/fl/fl/sync.sh +++ b/examples/fl/fl/sync.sh @@ -1,7 +1,9 @@ cp mlcube/workspace/parameters.yaml mlcube_agg/workspace/parameters.yaml cp mlcube/workspace/parameters.yaml mlcube_col1/workspace/parameters.yaml cp mlcube/workspace/parameters.yaml mlcube_col2/workspace/parameters.yaml +cp mlcube/workspace/parameters.yaml mlcube_col3/workspace/parameters.yaml cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col3/mlcube.yaml diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh index f05b2551a..1b751ce37 100644 --- a/examples/fl/fl/test.sh +++ b/examples/fl/fl/test.sh @@ -1,3 +1,4 @@ gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273; bash" -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col1 --task train -e COLLABORATOR_CN=col1; bash" -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col2 --task train -e COLLABORATOR_CN=col2; bash" +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col1 --task train -e COLLABORATOR_CN=abc; bash" +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col2 --task train -e 
COLLABORATOR_CN=defg; bash" +gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col3 --task train -e COLLABORATOR_CN=hij; bash" From 710652a6840caff9e60bd938710d1c816f1c0832 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 26 Mar 2024 03:52:15 +0100 Subject: [PATCH 022/242] fix training tests --- cli/cli_tests_training.sh | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 76eaf0ba3..22adcfa15 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -298,15 +298,14 @@ echo "\n" echo "=====================================" echo "Starting aggregator" echo "=====================================" -RUNCOMMAND="medperf aggregator start -a $AGG_UID -t $TRAINING_UID" -nohup $RUNCOMMAND agg.log & +AGGCOMMAND="medperf aggregator start -a $AGG_UID -t $TRAINING_UID" +nohup $AGGCOMMAND agg.log & # sleep so that the mlcube is run before we change profiles sleep 7 -AGG_PID=$(ps -ef | grep "$RUNCOMMAND" | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) # Check if the command is still running. -if ! kill -0 "$AGG_PID" &>/dev/null; then +if [ -z $(pgrep -xf "$AGGCOMMAND") ]; then checkFailed "agg doesn't seem to be running" 1 fi ########################################################## @@ -327,15 +326,14 @@ echo "\n" echo "=====================================" echo "Starting training with data1" echo "=====================================" -RUNCOMMAND="medperf training run -d $DSET_1_UID -t $TRAINING_UID" -nohup $RUNCOMMAND col1.log & +DATA1COMMAND="medperf training run -d $DSET_1_UID -t $TRAINING_UID" +nohup $DATA1COMMAND col1.log & # sleep so that the mlcube is run before we change profiles sleep 7 -DATA1_PID=$(ps -ef | grep "$RUNCOMMAND" | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) # Check if the command is still running. -if ! 
kill -0 "$DATA1_PID" &>/dev/null; then +if [ -z $(pgrep -xf "$DATA1COMMAND") ]; then checkFailed "data1 training doesn't seem to be running" 1 fi ########################################################## @@ -372,9 +370,9 @@ echo "=====================================" # string is the most efficient way to reduce that probability further. # Followup NOTE: not sure, but the "wait" command may fail if it is waiting for # a process that is not a child of the current shell -wait $DATA1_PID +wait $(pgrep -xf "$DATA1COMMAND") checkFailed "data1 training didn't exit successfully" -wait $AGG_PID +wait $(pgrep -xf "$AGGCOMMAND") checkFailed "aggregator didn't exit successfully" ########################################################## From b3e4e52b30c341c79c0e0512933d268bd3487498 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 26 Mar 2024 03:56:56 +0100 Subject: [PATCH 023/242] empty From 0181e47f7323ddac1a8ea9eae2c47b812bc443e7 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 26 Mar 2024 04:54:36 +0100 Subject: [PATCH 024/242] post-merging-main fixes --- cli/cli_tests_training.sh | 46 +++++++++++++++++++--------- cli/medperf/entities/training_exp.py | 7 +++-- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 22adcfa15..b3a6c3c28 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -90,11 +90,11 @@ echo "=====================================" echo "Submit cubes" echo "=====================================" -medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE +medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE --operational checkFailed "Train prep submission failed" PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) -medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -p $TRAIN_PARAMS -a $TRAIN_WEIGHTS +medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -p $TRAIN_PARAMS -a $TRAIN_WEIGHTS --operational 
checkFailed "traincube submission failed" TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) ########################################################## @@ -160,24 +160,33 @@ checkFailed "testdata1 profile activation failed" echo "\n" +########################################################## +echo "=====================================" +echo "Running data1 submission step" +echo "=====================================" +medperf dataset submit -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" -y +checkFailed "Data1 submission step failed" +DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Running data1 preparation step" echo "=====================================" -medperf dataset create -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" +medperf dataset prepare -d $DSET_1_UID checkFailed "Data1 preparation step failed" -DSET_1_GENUID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Running data1 submission step" +echo "Running data1 set_operational step" echo "=====================================" -medperf dataset submit -d $DSET_1_GENUID -y -checkFailed "Data1 submission step failed" -DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +medperf dataset set_operational -d $DSET_1_UID -y +checkFailed "Data1 set_operational step failed" ########################################################## echo "\n" @@ 
-202,24 +211,33 @@ checkFailed "testdata2 profile activation failed" echo "\n" +########################################################## +echo "=====================================" +echo "Running data2 submission step" +echo "=====================================" +medperf dataset submit -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" -y +checkFailed "Data2 submission step failed" +DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Running data2 preparation step" echo "=====================================" -medperf dataset create -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" +medperf dataset prepare -d $DSET_2_UID checkFailed "Data2 preparation step failed" -DSET_2_GENUID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Running data2 submission step" +echo "Running data2 set_operational step" echo "=====================================" -medperf dataset submit -d $DSET_2_GENUID -y -checkFailed "Data2 submission step failed" -DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +medperf dataset set_operational -d $DSET_2_UID -y +checkFailed "Data2 set_operational step failed" ########################################################## echo "\n" diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index 6d4e3ee04..50f050361 100644 --- a/cli/medperf/entities/training_exp.py +++ 
b/cli/medperf/entities/training_exp.py @@ -10,7 +10,7 @@ from medperf.utils import get_dataset_common_name from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema -from medperf.account_management import get_medperf_user_data, read_user_account +from medperf.account_management import get_medperf_user_data class TrainingExp( @@ -37,7 +37,6 @@ class TrainingExp( datasets: List[int] = None metadata: dict = {} user_metadata: dict = {} - state: str = "DEVELOPMENT" @validator("datasets", pre=True, always=True) def set_default_datasets_value(cls, value, values, **kwargs): @@ -216,7 +215,9 @@ def __get_local_dict(cls, training_exp_uid) -> dict: dict: information of the training_exp """ logging.info(f"Retrieving training_exp {training_exp_uid} from local storage") - training_exp_storage = os.path.join(config.training_folder, str(training_exp_uid)) + training_exp_storage = os.path.join( + config.training_folder, str(training_exp_uid) + ) training_exp_file = os.path.join( training_exp_storage, config.training_exps_filename ) From ff3c4ad1b390a26e678ffabfb751bc2b714cd37e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 26 Mar 2024 05:25:08 +0100 Subject: [PATCH 025/242] debug --- cli/cli_tests_training.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index b3a6c3c28..80d957592 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -323,7 +323,7 @@ nohup $AGGCOMMAND agg.log & sleep 7 # Check if the command is still running. -if [ -z $(pgrep -xf "$AGGCOMMAND") ]; then +if [ -z $(pgrep -f "$AGGCOMMAND") ]; then checkFailed "agg doesn't seem to be running" 1 fi ########################################################## @@ -351,7 +351,7 @@ nohup $DATA1COMMAND col1.log & sleep 7 # Check if the command is still running. 
-if [ -z $(pgrep -xf "$DATA1COMMAND") ]; then +if [ -z $(pgrep -f "$DATA1COMMAND") ]; then checkFailed "data1 training doesn't seem to be running" 1 fi ########################################################## @@ -388,9 +388,9 @@ echo "=====================================" # string is the most efficient way to reduce that probability further. # Followup NOTE: not sure, but the "wait" command may fail if it is waiting for # a process that is not a child of the current shell -wait $(pgrep -xf "$DATA1COMMAND") +wait $(pgrep -f "$DATA1COMMAND") checkFailed "data1 training didn't exit successfully" -wait $(pgrep -xf "$AGGCOMMAND") +wait $(pgrep -f "$AGGCOMMAND") checkFailed "aggregator didn't exit successfully" ########################################################## From 4d4c53ea8097125341059a6547571d5d9a97f165 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 2 Apr 2024 04:14:43 +0200 Subject: [PATCH 026/242] fix testing util --- cli/tests_setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index 7086bfd5d..b3bed9f29 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -64,7 +64,7 @@ checkSucceeded() { if [ "$?" 
-eq 0 ]; then i_am_a_command_that_does_not_exist_and_hence_changes_the_last_exit_status_to_nonzero fi - checkFailed $1 + checkFailed "$1" } if ${FRESH}; then From 03ebfbbb67e8e831cb70d413bfcfbc7dcbd60c7a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 2 Apr 2024 05:37:20 +0200 Subject: [PATCH 027/242] fix training tests --- cli/cli_tests_training.sh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 80d957592..463780e31 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -316,14 +316,14 @@ echo "\n" echo "=====================================" echo "Starting aggregator" echo "=====================================" -AGGCOMMAND="medperf aggregator start -a $AGG_UID -t $TRAINING_UID" -nohup $AGGCOMMAND agg.log & +medperf aggregator start -a $AGG_UID -t $TRAINING_UID agg.log 2>&1 & +AGG_PID=$! # sleep so that the mlcube is run before we change profiles sleep 7 # Check if the command is still running. -if [ -z $(pgrep -f "$AGGCOMMAND") ]; then +if [ ! -d "/proc/$AGG_PID" ]; then checkFailed "agg doesn't seem to be running" 1 fi ########################################################## @@ -344,14 +344,14 @@ echo "\n" echo "=====================================" echo "Starting training with data1" echo "=====================================" -DATA1COMMAND="medperf training run -d $DSET_1_UID -t $TRAINING_UID" -nohup $DATA1COMMAND col1.log & +medperf training run -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & +COL1_PID=$! # sleep so that the mlcube is run before we change profiles sleep 7 # Check if the command is still running. -if [ -z $(pgrep -f "$DATA1COMMAND") ]; then +if [ ! 
-d "/proc/$COL1_PID" ]; then checkFailed "data1 training doesn't seem to be running" 1 fi ########################################################## @@ -388,9 +388,9 @@ echo "=====================================" # string is the most efficient way to reduce that probability further. # Followup NOTE: not sure, but the "wait" command may fail if it is waiting for # a process that is not a child of the current shell -wait $(pgrep -f "$DATA1COMMAND") +wait $COL1_PID checkFailed "data1 training didn't exit successfully" -wait $(pgrep -f "$AGGCOMMAND") +wait $AGG_PID checkFailed "aggregator didn't exit successfully" ########################################################## From 20ec858ea57afcc1e183c69e3d01b65ec9b4e9bc Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 2 Apr 2024 07:44:31 +0200 Subject: [PATCH 028/242] use ip address for aggregtor --- cli/cli_tests_training.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 463780e31..dfda1f4af 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -132,7 +132,7 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -HOSTNAME_=$(hostname -A | cut -d " " -f 1) +HOSTNAME_=$(hostname -I | cut -d " " -f 1) medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) From f3910b8156d44a107c305d3215b2d9a8eeb6e766 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Apr 2024 16:22:00 +0200 Subject: [PATCH 029/242] aggregator changes --- server/aggregator/migrations/0001_initial.py | 31 -------------------- server/aggregator/models.py | 11 +++++-- server/aggregator/serializers.py | 2 +- server/aggregator/urls.py | 5 +--- 4 files changed, 11 insertions(+), 38 deletions(-) delete mode 100644 
server/aggregator/migrations/0001_initial.py diff --git a/server/aggregator/migrations/0001_initial.py b/server/aggregator/migrations/0001_initial.py deleted file mode 100644 index e5548f0fb..000000000 --- a/server/aggregator/migrations/0001_initial.py +++ /dev/null @@ -1,31 +0,0 @@ -# Generated by Django 3.2.20 on 2023-09-29 01:02 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name='Aggregator', - fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('name', models.CharField(max_length=20, unique=True)), - ('server_config', models.JSONField(blank=True, default=dict, null=True)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('modified_at', models.DateTimeField(auto_now=True)), - ('owner', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), - ], - options={ - 'ordering': ['created_at'], - }, - ), - ] diff --git a/server/aggregator/models.py b/server/aggregator/models.py index 96efd6485..683e9441f 100644 --- a/server/aggregator/models.py +++ b/server/aggregator/models.py @@ -7,12 +7,19 @@ class Aggregator(models.Model): owner = models.ForeignKey(User, on_delete=models.PROTECT) name = models.CharField(max_length=20, unique=True) - server_config = models.JSONField(default=dict, blank=True, null=True) + address = models.CharField(max_length=300) + port = models.IntegerField() + aggregation_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="aggregators", + ) + metadata = models.JSONField(default=dict, blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) def __str__(self): - return 
self.server_config + return self.address class Meta: ordering = ["created_at"] diff --git a/server/aggregator/serializers.py b/server/aggregator/serializers.py index acfd0726f..89eabfbbb 100644 --- a/server/aggregator/serializers.py +++ b/server/aggregator/serializers.py @@ -6,4 +6,4 @@ class AggregatorSerializer(serializers.ModelSerializer): class Meta: model = Aggregator fields = "__all__" - read_only_fields = ["owner"] \ No newline at end of file + read_only_fields = ["owner"] diff --git a/server/aggregator/urls.py b/server/aggregator/urls.py index 0c86c9197..08ef74148 100644 --- a/server/aggregator/urls.py +++ b/server/aggregator/urls.py @@ -1,12 +1,9 @@ from django.urls import path from . import views -from aggregator_association import views as aviews app_name = "aggregator" urlpatterns = [ path("", views.AggregatorList.as_view()), path("/", views.AggregatorDetail.as_view()), - path("training_experiments/", aviews.ExperimentAggregatorList.as_view()), - path("/training_experiments//", aviews.AggregatorApproval.as_view()), -] \ No newline at end of file +] From 36c4ddcafc0973c00a253d42151b814acaa1172b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Apr 2024 16:24:24 +0200 Subject: [PATCH 030/242] add ca server entity --- server/ca/__init__.py | 0 server/ca/admin.py | 3 ++ server/ca/apps.py | 6 ++++ server/ca/migrations/__init__.py | 0 server/ca/models.py | 21 +++++++++++++ server/ca/serializers.py | 9 ++++++ server/ca/urls.py | 9 ++++++ server/ca/views.py | 52 ++++++++++++++++++++++++++++++++ 8 files changed, 100 insertions(+) create mode 100644 server/ca/__init__.py create mode 100644 server/ca/admin.py create mode 100644 server/ca/apps.py create mode 100644 server/ca/migrations/__init__.py create mode 100644 server/ca/models.py create mode 100644 server/ca/serializers.py create mode 100644 server/ca/urls.py create mode 100644 server/ca/views.py diff --git a/server/ca/__init__.py b/server/ca/__init__.py new file mode 100644 index 000000000..e69de29bb diff 
--git a/server/ca/admin.py b/server/ca/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/ca/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/server/ca/apps.py b/server/ca/apps.py new file mode 100644 index 000000000..5bfe9cfe8 --- /dev/null +++ b/server/ca/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CaConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ca" diff --git a/server/ca/migrations/__init__.py b/server/ca/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca/models.py b/server/ca/models.py new file mode 100644 index 000000000..322c7b098 --- /dev/null +++ b/server/ca/models.py @@ -0,0 +1,21 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class CA(models.Model): + owner = models.ForeignKey(User, on_delete=models.PROTECT) + name = models.CharField(max_length=20, unique=True) + address = models.CharField(max_length=300) + port = models.IntegerField() + fingerprint = models.TextField() + metadata = models.JSONField(default=dict, blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return self.address + + class Meta: + ordering = ["created_at"] diff --git a/server/ca/serializers.py b/server/ca/serializers.py new file mode 100644 index 000000000..d693058da --- /dev/null +++ b/server/ca/serializers.py @@ -0,0 +1,9 @@ +from rest_framework import serializers +from .models import CA + + +class CASerializer(serializers.ModelSerializer): + class Meta: + model = CA + fields = "__all__" + read_only_fields = ["owner"] diff --git a/server/ca/urls.py b/server/ca/urls.py new file mode 100644 index 000000000..70515aaf6 --- /dev/null +++ b/server/ca/urls.py @@ -0,0 +1,9 @@ +from django.urls import path +from . 
import views + +app_name = "ca" + +urlpatterns = [ + path("", views.CAList.as_view()), + path("/", views.CADetail.as_view()), +] diff --git a/server/ca/views.py b/server/ca/views.py new file mode 100644 index 000000000..8a2bc6036 --- /dev/null +++ b/server/ca/views.py @@ -0,0 +1,52 @@ +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .models import CA +from .serializers import CASerializer +from drf_spectacular.utils import extend_schema + + +class CAList(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + @extend_schema(operation_id="cas_retrieve_all") + def get(self, request, format=None): + """ + List all cas + """ + cas = CA.objects.all() + cas = self.paginate_queryset(cas) + serializer = CASerializer(cas, many=True) + return self.get_paginated_response(serializer.data) + + def post(self, request, format=None): + """ + Create a new CA + """ + serializer = CASerializer(data=request.data) + if serializer.is_valid(): + serializer.save(owner=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class CADetail(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + return CA.objects.get(pk=pk) + except CA.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve an ca instance. 
+ """ + ca = self.get_object(pk) + serializer = CASerializer(ca) + return Response(serializer.data) From 687f8e7f0518559bde2b42ba6d3fc9c8f5f7fa56 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Apr 2024 16:25:41 +0200 Subject: [PATCH 031/242] associations refactoring --- server/benchmarkdataset/serializers.py | 56 ++++++++++---------------- server/benchmarkmodel/serializers.py | 56 ++++++++++---------------- server/utils/associations.py | 39 ++++++++++++++++++ 3 files changed, 82 insertions(+), 69 deletions(-) create mode 100644 server/utils/associations.py diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py index 9cc120079..fa7a16cdc 100644 --- a/server/benchmarkdataset/serializers.py +++ b/server/benchmarkdataset/serializers.py @@ -4,6 +4,10 @@ from dataset.models import Dataset from .models import BenchmarkDataset +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) class BenchmarkDatasetListSerializer(serializers.ModelSerializer): @@ -12,61 +16,48 @@ class Meta: read_only_fields = ["initiated_by", "approved_at"] fields = "__all__" - def __validate_approval_status(self, last_benchmarkdataset, approval_status): - if not last_benchmarkdataset: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_benchmarkdataset.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_benchmarkdataset.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) - def validate(self, 
data): bid = self.context["request"].data.get("benchmark") dataset = self.context["request"].data.get("dataset") approval_status = self.context["request"].data.get("approval_status", "PENDING") + + # benchmark state benchmark = Benchmark.objects.get(pk=bid) benchmark_state = benchmark.state if benchmark_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational benchmark" ) + + # benchmark approval status benchmark_approval_status = benchmark.approval_status if benchmark_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved benchmark" ) + + # dataset state dataset_obj = Dataset.objects.get(pk=dataset) dataset_state = dataset_obj.state if dataset_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational dataset" ) + + # dataset prep mlcube if dataset_obj.data_preparation_mlcube != benchmark.data_preparation_mlcube: raise serializers.ValidationError( "Dataset association request can be made only if the dataset" " was prepared with benchmark's data preparation MLCube" ) + + # approval status last_benchmarkdataset = ( BenchmarkDataset.objects.filter(benchmark__id=bid, dataset__id=dataset) .order_by("-created_at") .first() ) - self.__validate_approval_status(last_benchmarkdataset, approval_status) + validate_approval_status_on_creation(last_benchmarkdataset, approval_status) return data @@ -75,10 +66,11 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["dataset"].owner.id == validated_data["benchmark"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() return BenchmarkDataset.objects.create(**validated_data) @@ -103,17 +95,11 @@ def validate(self, data): def validate_approval_status(self, cur_approval_status): 
last_approval_status = self.instance.approval_status - if last_approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject only a pending request" - ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if cur_approval_status == "APPROVED": - if current_user.id == initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) return cur_approval_status def update(self, instance, validated_data): diff --git a/server/benchmarkmodel/serializers.py b/server/benchmarkmodel/serializers.py index afa34acd4..57cd1e2ab 100644 --- a/server/benchmarkmodel/serializers.py +++ b/server/benchmarkmodel/serializers.py @@ -4,6 +4,10 @@ from mlcube.models import MlCube from .models import BenchmarkModel +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) class BenchmarkModelListSerializer(serializers.ModelSerializer): @@ -16,49 +20,38 @@ def validate(self, data): bid = self.context["request"].data.get("benchmark") mlcube = self.context["request"].data.get("model_mlcube") approval_status = self.context["request"].data.get("approval_status", "PENDING") + + # benchmark state benchmark = Benchmark.objects.get(pk=bid) benchmark_state = benchmark.state if benchmark_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational benchmark" ) + + # benchmark approval status benchmark_approval_status = benchmark.approval_status if benchmark_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved benchmark" ) - mlcube_state = MlCube.objects.get(pk=mlcube).state + + # mlcube state + mlcube_obj = MlCube.objects.get(pk=mlcube) + mlcube_state = mlcube_obj.state if mlcube_state != 
"OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational model mlcube" ) + + # approval status last_benchmarkmodel = ( BenchmarkModel.objects.filter(benchmark__id=bid, model_mlcube__id=mlcube) .order_by("-created_at") .first() ) - if not last_benchmarkmodel: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_benchmarkmodel.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - # check valid results passed - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_benchmarkmodel.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) + validate_approval_status_on_creation(last_benchmarkmodel, approval_status) + return data def create(self, validated_data): @@ -66,10 +59,11 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["model_mlcube"].owner.id == validated_data["benchmark"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() return BenchmarkModel.objects.create(**validated_data) @@ -95,17 +89,11 @@ def validate(self, data): def validate_approval_status(self, cur_approval_status): last_approval_status = self.instance.approval_status - if last_approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject only a pending request" - ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if cur_approval_status == 
"APPROVED": - if current_user.id == initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) return cur_approval_status def update(self, instance, validated_data): diff --git a/server/utils/associations.py b/server/utils/associations.py new file mode 100644 index 000000000..54014a82f --- /dev/null +++ b/server/utils/associations.py @@ -0,0 +1,39 @@ +from rest_framework import serializers + + +def validate_approval_status_on_creation(last_association, approval_status): + if not last_association: + if approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject association request only if there are prior requests" + ) + else: + if approval_status == "PENDING": + if last_association.approval_status != "REJECTED": + raise serializers.ValidationError( + "User can create a new request only if prior request is rejected" + ) + elif approval_status == "APPROVED": + raise serializers.ValidationError( + "User cannot create an approved association request" + ) + # approval_status == "REJECTED": + else: + if last_association.approval_status != "APPROVED": + raise serializers.ValidationError( + "User can reject request only if prior request is approved" + ) + + +def validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user +): + if last_approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject only a pending request" + ) + if cur_approval_status == "APPROVED": + if current_user.id == initiated_user.id: + raise serializers.ValidationError( + "Same user cannot approve the association request" + ) From c7883b5b60f7e7cbb8493e262d31edb6c2a996da Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Apr 2024 16:26:11 +0200 Subject: [PATCH 032/242] remove signing code --- 
server/key_storage/__init__.py | 0 server/key_storage/gcloud_secret_manager.py | 11 -- server/key_storage/local.py | 20 --- server/signing/__init__.py | 0 server/signing/cryptography/__init__.py | 3 - server/signing/cryptography/ca.py | 150 -------------------- server/signing/cryptography/io.py | 129 ----------------- server/signing/cryptography/participant.py | 72 ---------- server/signing/cryptography/utils.py | 14 -- server/signing/interface.py | 67 --------- 10 files changed, 466 deletions(-) delete mode 100644 server/key_storage/__init__.py delete mode 100644 server/key_storage/gcloud_secret_manager.py delete mode 100644 server/key_storage/local.py delete mode 100644 server/signing/__init__.py delete mode 100644 server/signing/cryptography/__init__.py delete mode 100644 server/signing/cryptography/ca.py delete mode 100644 server/signing/cryptography/io.py delete mode 100644 server/signing/cryptography/participant.py delete mode 100644 server/signing/cryptography/utils.py delete mode 100644 server/signing/interface.py diff --git a/server/key_storage/__init__.py b/server/key_storage/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/server/key_storage/gcloud_secret_manager.py b/server/key_storage/gcloud_secret_manager.py deleted file mode 100644 index d36edd069..000000000 --- a/server/key_storage/gcloud_secret_manager.py +++ /dev/null @@ -1,11 +0,0 @@ -class GcloudSecretStorage: - def __init__(self, filepath): - raise NotImplementedError - - def write(self, key, storage_id): - # NOTE: use one secret per deployment. 
- # store keys as secret versions - raise NotImplementedError - - def read(self, storage_id): - raise NotImplementedError diff --git a/server/key_storage/local.py b/server/key_storage/local.py deleted file mode 100644 index 00cfe60c3..000000000 --- a/server/key_storage/local.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -from signing.cryptography.io import write_key, read_key - - -class LocalSecretStorage: - """NOT SUITABLE FOR PRODUCTION. it simply stores keys - in filesystem.""" - - def __init__(self, folderpath): - os.makedirs(folderpath, exist_ok=True) - self.folderpath = folderpath - - def write(self, key, storage_id): - filepath = os.path.join(self.folderpath, storage_id) - write_key(key, filepath) - - def read(self, storage_id): - filepath = os.path.join(self.folderpath, storage_id) - key = read_key(filepath) - return key diff --git a/server/signing/__init__.py b/server/signing/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/server/signing/cryptography/__init__.py b/server/signing/cryptography/__init__.py deleted file mode 100644 index b3f394d12..000000000 --- a/server/signing/cryptography/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -"""openfl.cryptography package.""" diff --git a/server/signing/cryptography/ca.py b/server/signing/cryptography/ca.py deleted file mode 100644 index d651919a4..000000000 --- a/server/signing/cryptography/ca.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography CA utilities.""" - -import datetime -import uuid -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.x509.base import 
Certificate -from cryptography.x509.base import CertificateSigningRequest -from cryptography.x509.extensions import ExtensionNotFound -from cryptography.x509.name import Name -from cryptography.x509.oid import ExtensionOID -from cryptography.x509.oid import NameOID - - -def generate_root_cert( - common_name: str = "Simple Root CA", days_to_expiration: int = 365 -) -> Tuple[RSAPrivateKey, Certificate]: - """Generate_root_certificate.""" - now = datetime.datetime.utcnow() - expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) - - # Generate private key - root_private_key = rsa.generate_private_key( - public_exponent=65537, key_size=3072, backend=default_backend() - ) - - # Generate public key - root_public_key = root_private_key.public_key() - builder = x509.CertificateBuilder() - subject = x509.Name( - [ - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Root CA"), - ] - ) - issuer = subject - builder = builder.subject_name(subject) - builder = builder.issuer_name(issuer) - - builder = builder.not_valid_before(now) - builder = builder.not_valid_after(now + expiration_delta) - builder = builder.serial_number(int(uuid.uuid4())) - builder = builder.public_key(root_public_key) - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - - # Sign the CSR - certificate = builder.sign( - private_key=root_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - - return root_private_key, certificate - - -def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Generate signing CSR.""" - # Generate private key - signing_private_key = rsa.generate_private_key( - public_exponent=65537, key_size=3072, backend=default_backend() - 
) - - builder = x509.CertificateSigningRequestBuilder() - subject = x509.Name( - [ - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), - x509.NameAttribute(NameOID.COMMON_NAME, "Simple Signing CA"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Signing CA"), - ] - ) - builder = builder.subject_name(subject) - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - - # Sign the CSR - csr = builder.sign( - private_key=signing_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - - return signing_private_key, csr - - -def sign_certificate( - csr: CertificateSigningRequest, - issuer_private_key: RSAPrivateKey, - issuer_name: Name, - days_to_expiration: int = 365, - ca: bool = False, -) -> Certificate: - """ - Sign the incoming CSR request. - - Args: - csr : Certificate Signing Request object - issuer_private_key : Root CA private key if the request is for the signing - CA; Signing CA private key otherwise - issuer_name : x509 Name - days_to_expiration : int (365 days by default) - ca : Is this a certificate authority - """ - now = datetime.datetime.utcnow() - expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) - - builder = x509.CertificateBuilder() - builder = builder.subject_name(csr.subject) - builder = builder.issuer_name(issuer_name) - builder = builder.not_valid_before(now) - builder = builder.not_valid_after(now + expiration_delta) - builder = builder.serial_number(int(uuid.uuid4())) - builder = builder.public_key(csr.public_key()) - builder = builder.add_extension( - x509.BasicConstraints(ca=ca, path_length=None), - critical=True, - ) - try: - builder = builder.add_extension( - csr.extensions.get_extension_for_oid( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME - ).value, - critical=False, - ) - except ExtensionNotFound: - pass # Might not 
have alternative name - - signed_cert = builder.sign( - private_key=issuer_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - return signed_cert diff --git a/server/signing/cryptography/io.py b/server/signing/cryptography/io.py deleted file mode 100644 index 52bfc5e95..000000000 --- a/server/signing/cryptography/io.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography IO utilities.""" - -import os -from hashlib import sha384 -from pathlib import Path -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.hazmat.primitives.serialization import load_pem_private_key -from cryptography.x509.base import Certificate -from cryptography.x509.base import CertificateSigningRequest - - -def read_key(path: Path) -> RSAPrivateKey: - """ - Read private key. - - Args: - path : Path (pathlib) - - Returns: - private_key - """ - with open(path, 'rb') as f: - pem_data = f.read() - - signing_key = load_pem_private_key(pem_data, password=None) - # TODO: replace assert with exception / sys.exit - assert (isinstance(signing_key, rsa.RSAPrivateKey)) - return signing_key - - -def write_key(key: RSAPrivateKey, path: Path) -> None: - """ - Write private key. - - Args: - key : RSA private key object - path : Path (pathlib) - - """ - def key_opener(path, flags): - return os.open(path, flags, mode=0o600) - - with open(path, 'wb', opener=key_opener) as f: - f.write(key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) - - -def read_crt(path: Path) -> Certificate: - """ - Read signed TLS certificate. 
- - Args: - path : Path (pathlib) - - Returns: - Cryptography TLS Certificate object - """ - with open(path, 'rb') as f: - pem_data = f.read() - - certificate = x509.load_pem_x509_certificate(pem_data) - # TODO: replace assert with exception / sys.exit - assert (isinstance(certificate, x509.Certificate)) - return certificate - - -def write_crt(certificate: Certificate, path: Path) -> None: - """ - Write cryptography certificate / csr. - - Args: - certificate : cryptography csr / certificate object - path : Path (pathlib) - - Returns: - Cryptography TLS Certificate object - """ - with open(path, 'wb') as f: - f.write(certificate.public_bytes( - encoding=serialization.Encoding.PEM, - )) - - -def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]: - """ - Read certificate signing request. - - Args: - path : Path (pathlib) - - Returns: - Cryptography CSR object - """ - with open(path, 'rb') as f: - pem_data = f.read() - - csr = x509.load_pem_x509_csr(pem_data) - # TODO: replace assert with exception / sys.exit - assert (isinstance(csr, x509.CertificateSigningRequest)) - return csr, get_csr_hash(csr) - - -def get_csr_hash(certificate: CertificateSigningRequest) -> str: - """ - Get hash of cryptography certificate. 
- - Args: - certificate : Cryptography CSR object - - Returns: - Hash of cryptography certificate / csr - """ - hasher = sha384() - encoded_bytes = certificate.public_bytes( - encoding=serialization.Encoding.PEM, - ) - hasher.update(encoded_bytes) - return hasher.hexdigest() diff --git a/server/signing/cryptography/participant.py b/server/signing/cryptography/participant.py deleted file mode 100644 index d6e94712b..000000000 --- a/server/signing/cryptography/participant.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography participant utilities.""" -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.x509.base import CertificateSigningRequest -from cryptography.x509.oid import NameOID - - -def generate_csr(common_name: str, - server: bool = False) -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Issue certificate signing request for server and client.""" - # Generate private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend() - ) - - builder = x509.CertificateSigningRequestBuilder() - subject = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - ]) - builder = builder.subject_name(subject) - builder = builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True, - ) - if server: - builder = builder.add_extension( - x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.SERVER_AUTH]), - critical=True - ) - - else: - builder = builder.add_extension( - x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.CLIENT_AUTH]), - critical=True - ) - - builder = builder.add_extension( - x509.KeyUsage( - digital_signature=True, - 
key_encipherment=True, - data_encipherment=False, - key_agreement=False, - content_commitment=False, - key_cert_sign=False, - crl_sign=False, - encipher_only=False, - decipher_only=False - ), - critical=True - ) - - builder = builder.add_extension( - x509.SubjectAlternativeName([x509.DNSName(common_name)]), - critical=False - ) - - # Sign the CSR - csr = builder.sign( - private_key=private_key, algorithm=hashes.SHA384(), - backend=default_backend() - ) - - return private_key, csr diff --git a/server/signing/cryptography/utils.py b/server/signing/cryptography/utils.py deleted file mode 100644 index 03f9eb940..000000000 --- a/server/signing/cryptography/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from cryptography.hazmat.primitives import serialization -from cryptography import x509 - - -def cert_to_str(cert): - return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") - - -def str_to_cert(cert_str): - return x509.load_pem_x509_certificate(cert_str.encode("utf-8")) - - -def str_to_csr(csr_str): - return x509.load_pem_x509_csr(csr_str.encode("utf-8")) diff --git a/server/signing/interface.py b/server/signing/interface.py deleted file mode 100644 index 024931da7..000000000 --- a/server/signing/interface.py +++ /dev/null @@ -1,67 +0,0 @@ -from cryptography import x509 - -from django.conf import settings -from .cryptography.ca import generate_root_cert, sign_certificate -from .cryptography.utils import cert_to_str, str_to_cert, str_to_csr -from training.models import TrainingExperiment - - -def __get_experiment_key_pair(training_exp_id): - exp = TrainingExperiment.objects.get(pk=training_exp_id) - private_key_id = exp.private_key - private_key = settings.KEY_STORAGE.read(private_key_id) - public_key_str = exp.public_key - public_key = str_to_cert(public_key_str) - return private_key, public_key - - -def generate_key_pair(training_exp_id): - # TODO: do we need to destroy the keys at some point? 
- ca_common_name = f"training_{training_exp_id}" - root_private_key, certificate = generate_root_cert(ca_common_name) - - # store private key - storage_id = ca_common_name - settings.KEY_STORAGE.write(root_private_key, storage_id) - - # public key to str - public_key_str = cert_to_str(certificate) - return storage_id, public_key_str - - -def sign_csr(csr_str, training_exp_id): - # Load CSR - csr = str_to_csr(csr_str) - - # load signing key and crt - signing_key, signing_crt = __get_experiment_key_pair(training_exp_id) - - # sign - signed_cert = sign_certificate(csr, signing_key, signing_crt.subject) - - # cert as str - cert_str = cert_to_str(signed_cert) - - return cert_str - - -def verify_dataset_csr(csr_str, dataset_object, training_exp): - # TODO? - try: - csr = str_to_csr(csr_str) - except ValueError as e: - return False, str(e) - if not isinstance(csr, x509.CertificateSigningRequest): - return False, "Invalid CSR format" - return True, "" - - -def verify_aggregator_csr(csr_str, aggregator_object, training_exp, request): - # TODO? 
- try: - csr = str_to_csr(csr_str) - except ValueError as e: - return False, str(e) - if not isinstance(csr, x509.CertificateSigningRequest): - return False, "Invalid CSR format" - return True, "" From 1b356cb336fd57e92c230a04ba3f5ec56a2cfa3c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 22 Apr 2024 16:48:37 +0200 Subject: [PATCH 033/242] add training event object --- server/trainingevent/__init__.py | 0 server/trainingevent/admin.py | 3 ++ server/trainingevent/apps.py | 6 +++ server/trainingevent/migrations/__init__.py | 0 server/trainingevent/models.py | 17 ++++++ server/trainingevent/permissions.py | 50 +++++++++++++++++ server/trainingevent/serializers.py | 45 ++++++++++++++++ server/trainingevent/tests.py | 3 ++ server/trainingevent/views.py | 59 +++++++++++++++++++++ 9 files changed, 183 insertions(+) create mode 100644 server/trainingevent/__init__.py create mode 100644 server/trainingevent/admin.py create mode 100644 server/trainingevent/apps.py create mode 100644 server/trainingevent/migrations/__init__.py create mode 100644 server/trainingevent/models.py create mode 100644 server/trainingevent/permissions.py create mode 100644 server/trainingevent/serializers.py create mode 100644 server/trainingevent/tests.py create mode 100644 server/trainingevent/views.py diff --git a/server/trainingevent/__init__.py b/server/trainingevent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/trainingevent/admin.py b/server/trainingevent/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/trainingevent/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. 
diff --git a/server/trainingevent/apps.py b/server/trainingevent/apps.py new file mode 100644 index 000000000..9d1295a0f --- /dev/null +++ b/server/trainingevent/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class TrainingeventConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "trainingevent" diff --git a/server/trainingevent/migrations/__init__.py b/server/trainingevent/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/trainingevent/models.py b/server/trainingevent/models.py new file mode 100644 index 000000000..61376d77d --- /dev/null +++ b/server/trainingevent/models.py @@ -0,0 +1,17 @@ +from django.db import models +from training.models import TrainingExperiment + + +# Create your models here. +class TrainingEvent(models.Model): + finished = models.BooleanField(default=False) + training_exp = models.ForeignKey( + TrainingExperiment, on_delete=models.PROTECT, related_name="events" + ) + participants = models.JSONField() + report = models.JSONField(blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + finished_at = models.DateTimeField(null=True, blank=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/trainingevent/permissions.py b/server/trainingevent/permissions.py new file mode 100644 index 000000000..f0eb1ee94 --- /dev/null +++ b/server/trainingevent/permissions.py @@ -0,0 +1,50 @@ +from rest_framework.permissions import BasePermission +from .models import TrainingExperiment + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsExpOwner(BasePermission): + def get_object(self, tid): + try: + return TrainingExperiment.objects.get(pk=tid) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + tid = view.kwargs.get("tid", None) + if not tid: + return False + training_exp = self.get_object(tid) + if not 
training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False + + +class IsAggregatorOwner(BasePermission): + def get_object(self, tid): + try: + return TrainingExperiment.objects.get(pk=tid) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + tid = view.kwargs.get("tid", None) + if not tid: + return False + training_exp = self.get_object(tid) + if not training_exp: + return False + aggregator = training_exp.aggregator + if not aggregator: + return False + if aggregator.owner.id == request.user.id: + return True + else: + return False diff --git a/server/trainingevent/serializers.py b/server/trainingevent/serializers.py new file mode 100644 index 000000000..7f3974077 --- /dev/null +++ b/server/trainingevent/serializers.py @@ -0,0 +1,45 @@ +from rest_framework import serializers +from .models import TrainingEvent +from training.models import TrainingExperiment +from django.utils import timezone + + +class EventSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingEvent + fields = "__all__" + read_only_fields = ["finished", "finished_at", "report"] + + def validate(self, data): + training_exp = TrainingExperiment.objects.get(pk=data["training_exp"]) + if training_exp.approval_status != "APPROVED": + raise serializers.ValidationError( + "User cannot create an event unless the experiment is approved" + ) + prev_event = training_exp.event + if prev_event and not training_exp.event.finished: + raise serializers.ValidationError( + "User cannot create a new event unless the previous event has finished" + ) + + return data + + +class EventDetailSerializer(serializers.ModelSerializer): + class Meta: + model = TrainingEvent + fields = "__all__" + read_only_fields = ["finished_at", "training_exp", "participants", "finished"] + + def validate(self, data): + if self.instance.finished: + raise serializers.ValidationError("User cannot edit a finished event") + 
return data + + def update(self, instance, validated_data): + if "report" in validated_data: + instance.report = validated_data["report"] + instance.finished = True + instance.finished_at = timezone.now() + instance.save() + return instance diff --git a/server/trainingevent/tests.py b/server/trainingevent/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/server/trainingevent/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/server/trainingevent/views.py b/server/trainingevent/views.py new file mode 100644 index 000000000..a55d5a6e7 --- /dev/null +++ b/server/trainingevent/views.py @@ -0,0 +1,59 @@ +from .models import TrainingEvent +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsExpOwner, IsAggregatorOwner +from .serializers import EventSerializer, EventDetailSerializer + + +class EventList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner] + serializer_class = EventSerializer + queryset = "" + + def post(self, request, format=None): + """ + Create an event for an experiment + """ + serializer = EventSerializer(data=request.data) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class EventDetail(GenericAPIView): + serializer_class = EventDetailSerializer + queryset = "" + + def get_permissions(self): + if self.request.method == "PUT": + self.permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] + return super(self.__class__, self).get_permissions() + + def get_object(self, tid): + try: + return TrainingEvent.objects.filter(training_exp__id=tid) + except TrainingEvent.DoesNotExist: + raise Http404 + + def get(self, request, tid, format=None): + """ + Retrieve events of a 
training experiment + """ + event = self.get_object(tid) + serializer = EventDetailSerializer(event, many=True) + return Response(serializer.data) + + def put(self, request, tid, format=None): + """ + Update latest event of a training experiment + """ + event = self.get_object(tid).order_by("-created_at").first() + serializer = EventDetailSerializer(event, data=request.data) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) From 9205d9842dccfe075c836fa2f2d551967732b0b1 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 22 Apr 2024 22:45:57 +0200 Subject: [PATCH 034/242] modify training event --- server/trainingevent/views.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/trainingevent/views.py b/server/trainingevent/views.py index a55d5a6e7..f9f3d2309 100644 --- a/server/trainingevent/views.py +++ b/server/trainingevent/views.py @@ -1,4 +1,4 @@ -from .models import TrainingEvent +from training.models import TrainingExperiment from django.http import Http404 from rest_framework.generics import GenericAPIView from rest_framework.response import Response @@ -35,23 +35,28 @@ def get_permissions(self): def get_object(self, tid): try: - return TrainingEvent.objects.filter(training_exp__id=tid) - except TrainingEvent.DoesNotExist: + training_exp = TrainingExperiment.objects.get(pk=tid) + except TrainingExperiment.DoesNotExist: raise Http404 + event = training_exp.event + if not event: + raise Http404 + return event + def get(self, request, tid, format=None): """ - Retrieve events of a training experiment + Retrieve latest event of a training experiment """ event = self.get_object(tid) - serializer = EventDetailSerializer(event, many=True) + serializer = EventDetailSerializer(event) return Response(serializer.data) def put(self, request, tid, format=None): """ Update latest event of a training experiment """ - event = 
self.get_object(tid).order_by("-created_at").first() + event = self.get_object(tid) serializer = EventDetailSerializer(event, data=request.data) if serializer.is_valid(): serializer.save() From 7b9b3a14534b617cfe2a1a3d5006ada2be9947f7 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 22 Apr 2024 23:27:40 +0200 Subject: [PATCH 035/242] modify training event --- server/trainingevent/serializers.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/server/trainingevent/serializers.py b/server/trainingevent/serializers.py index 7f3974077..ca2080e3d 100644 --- a/server/trainingevent/serializers.py +++ b/server/trainingevent/serializers.py @@ -21,6 +21,16 @@ def validate(self, data): raise serializers.ValidationError( "User cannot create a new event unless the previous event has finished" ) + aggregator = training_exp.aggregator + if not aggregator: + raise serializers.ValidationError( + "User cannot create a new event if the experiment has no aggregator" + ) + plan = training_exp.plan + if plan is None: + raise serializers.ValidationError( + "User cannot create a new event if the experiment has no plan" + ) return data From dc6cadd4dba20d19f75d05a8a85ff626be9871f7 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 22 Apr 2024 23:27:52 +0200 Subject: [PATCH 036/242] changes to training entity --- server/training/migrations/0001_initial.py | 46 ------- server/training/models.py | 25 +++- server/training/permissions.py | 55 +++++++++ server/training/serializers.py | 126 +++++++++---------- server/training/urls.py | 6 +- server/training/views.py | 133 +++++++++++++++------ 6 files changed, 238 insertions(+), 153 deletions(-) delete mode 100644 server/training/migrations/0001_initial.py diff --git a/server/training/migrations/0001_initial.py b/server/training/migrations/0001_initial.py deleted file mode 100644 index c17bdb4a3..000000000 --- a/server/training/migrations/0001_initial.py +++ /dev/null @@ -1,46 +0,0 @@ -# Generated by Django 3.2.20 on 2023-09-29 01:02 
- -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('mlcube', '0001_initial'), - ] - - operations = [ - migrations.CreateModel( - name='TrainingExperiment', - fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('name', models.CharField(max_length=20, unique=True)), - ('description', models.CharField(blank=True, max_length=100)), - ('docs_url', models.CharField(blank=True, max_length=100)), - ('demo_dataset_tarball_url', models.CharField(blank=True, max_length=256)), - ('demo_dataset_tarball_hash', models.CharField(max_length=100)), - ('demo_dataset_generated_uid', models.CharField(max_length=128)), - ('metadata', models.JSONField(blank=True, default=dict, null=True)), - ('state', models.CharField(choices=[('DEVELOPMENT', 'DEVELOPMENT'), ('OPERATION', 'OPERATION')], default='DEVELOPMENT', max_length=100)), - ('is_valid', models.BooleanField(default=True)), - ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 'APPROVED'), ('REJECTED', 'REJECTED')], default='PENDING', max_length=100)), - ('private_key', models.CharField(blank=True, max_length=100)), - ('public_key', models.TextField(blank=True)), - ('user_metadata', models.JSONField(blank=True, default=dict, null=True)), - ('approved_at', models.DateTimeField(blank=True, null=True)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('modified_at', models.DateTimeField(auto_now=True)), - ('data_preparation_mlcube', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='training_exp', to='mlcube.mlcube')), - ('fl_mlcube', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='fl_mlcube', to='mlcube.mlcube')), - ('owner', 
models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), - ], - options={ - 'ordering': ['modified_at'], - }, - ), - ] diff --git a/server/training/models.py b/server/training/models.py index d86651f53..0726ed87f 100644 --- a/server/training/models.py +++ b/server/training/models.py @@ -19,7 +19,7 @@ class TrainingExperiment(models.Model): description = models.CharField(max_length=100, blank=True) docs_url = models.CharField(max_length=100, blank=True) owner = models.ForeignKey(User, on_delete=models.PROTECT) - demo_dataset_tarball_url = models.CharField(max_length=256, blank=True) + demo_dataset_tarball_url = models.CharField(max_length=256) demo_dataset_tarball_hash = models.CharField(max_length=100) demo_dataset_generated_uid = models.CharField(max_length=128) data_preparation_mlcube = models.ForeignKey( @@ -40,10 +40,7 @@ class TrainingExperiment(models.Model): approval_status = models.CharField( choices=EXP_STATUS, max_length=100, default="PENDING" ) - private_key = models.CharField(max_length=100, blank=True) - public_key = models.TextField(blank=True) - # TODO: ensure unique, but allow blank (how?) 
- # TODO: rethink if keys are always needed + plan = models.JSONField(blank=True, null=True) user_metadata = models.JSONField(default=dict, blank=True, null=True) approved_at = models.DateTimeField(null=True, blank=True) @@ -53,5 +50,23 @@ class TrainingExperiment(models.Model): def __str__(self): return self.name + @property + def event(self): + return self.events.all().order_by("created_at").last() + + @property + def aggregator(self): + aggregator_assoc = ( + self.aggregator_association_set.all().order_by("created_at").last() + ) + if aggregator_assoc and aggregator_assoc.approval_status == "APPROVED": + return aggregator_assoc.aggregator + + @property + def ca(self): + ca_assoc = self.ca_association_set.all().order_by("created_at").last() + if ca_assoc and ca_assoc.approval_status == "APPROVED": + return ca_assoc.ca + class Meta: ordering = ["modified_at"] diff --git a/server/training/permissions.py b/server/training/permissions.py index 98e59e048..7576d489a 100644 --- a/server/training/permissions.py +++ b/server/training/permissions.py @@ -1,5 +1,7 @@ from rest_framework.permissions import BasePermission from .models import TrainingExperiment +from traindataset_association.models import ExperimentDataset +from django.db.models import OuterRef, Subquery class IsAdmin(BasePermission): @@ -25,3 +27,56 @@ def has_permission(self, request, view): return True else: return False + + +# TODO: check effciency / database costs +class IsAssociatedDatasetOwner(BasePermission): + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + + if not request.user.is_authenticated: + # This check is to prevent internal server error + # since user.dataset_set is used below + return False + + latest_datasets_assocs_status = ( + ExperimentDataset.objects.all() + .filter(training_exp__id=pk, dataset__id=OuterRef("id")) + .order_by("-created_at")[:1] + .values("approval_status") + ) + + user_associated_datasets = ( + 
request.user.dataset_set.all() + .annotate(assoc_status=Subquery(latest_datasets_assocs_status)) + .filter(assoc_status="APPROVED") + ) + + if user_associated_datasets.exists(): + return True + else: + return False + + +class IsAggregatorOwner(BasePermission): + def has_permission(self, request, view): + pk = view.kwargs.get("pk", None) + if not pk: + return False + + if not request.user.is_authenticated: + # This check is to prevent internal server error + # since user.dataset_set is used below + return False + + training_exp = TrainingExperiment.objects.get(pk=pk) + aggregator = training_exp.aggregator + if not aggregator: + return False + + if aggregator.owner.id == request.user.id: + return True + else: + return False diff --git a/server/training/serializers.py b/server/training/serializers.py index 4c7f5e9be..b9df8a5c9 100644 --- a/server/training/serializers.py +++ b/server/training/serializers.py @@ -1,20 +1,13 @@ from rest_framework import serializers -from .models import TrainingExperiment -from signing.interface import generate_key_pair from django.utils import timezone +from .models import TrainingExperiment class WriteTrainingExperimentSerializer(serializers.ModelSerializer): class Meta: model = TrainingExperiment - exclude = ["private_key"] - read_only_fields = [ - "owner", - "private_key", - "public_key", - "approved_at", - "approval_status", - ] + fields = "__all__" + read_only_fields = ["owner", "approved_at", "approval_status"] def validate(self, data): owner = self.context["request"].user @@ -25,74 +18,81 @@ def validate(self, data): raise serializers.ValidationError( "User can own at most one pending experiment" ) - return data - - def save(self, **kwargs): - super().save(**kwargs) - # TODO: move key generation after admin approval? 
YES - # TODO: use atomic transaction - private_key_id, public_key = generate_key_pair(self.instance.id) - self.instance.private_key = private_key_id - self.instance.public_key = public_key - self.instance.save() + if "state" in data and data["state"] == "OPERATION": + dev_mlcubes = [ + data["data_preparation_mlcube"].state == "DEVELOPMENT", + data["fl_mlcube"].state == "DEVELOPMENT", + ] + if any(dev_mlcubes): + raise serializers.ValidationError( + "User cannot mark an experiment as operational" + " if its MLCubes are not operational" + ) - return self.instance + return data class ReadTrainingExperimentSerializer(serializers.ModelSerializer): class Meta: model = TrainingExperiment - exclude = ["private_key"] + read_only_fields = ["owner", "approved_at"] + fields = "__all__" def update(self, instance, validated_data): - # TODO: seems buggy - if ( - instance.approval_status != "PENDING" - and "approval_status" in validated_data - and validated_data["approval_status"] == "APPROVED" - ): - instance.approved_at = timezone.now() + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + validated_data.pop("approval_status", None) for k, v in validated_data.items(): setattr(instance, k, v) instance.save() return instance - def validate(self, data): - if "approval_status" in data: - if ( - data["approval_status"] == "PENDING" - and self.instance.approval_status != "PENDING" - ): - pending_experiments = TrainingExperiment.objects.filter( - owner=self.instance.owner, approval_status="PENDING" + def validate_approval_status(self, approval_status): + if approval_status == "PENDING": + raise serializers.ValidationError( + "User can only approve or reject an experiment" + ) + if approval_status == "APPROVED": + if self.instance.approval_status == "REJECTED": + raise 
serializers.ValidationError( + "User can approve only a pending request" ) - if len(pending_experiments) > 0: - raise serializers.ValidationError( - "User can own at most one pending experiment" - ) + return approval_status - editable_fields = [ - "is_valid", - "user_metadata", - "approval_status", - "demo_dataset_tarball_url", - ] - if self.instance.state == "DEVELOPMENT": - editable_fields.append("state") + def validate_state(self, state): + if state == "OPERATION" and self.instance.state != "OPERATION": + dev_mlcubes = [ + self.instance.data_preparation_mlcube.state == "DEVELOPMENT", + self.instance.fl_mlcube.state == "DEVELOPMENT", + ] + if any(dev_mlcubes): + raise serializers.ValidationError( + "User cannot mark an experiment as operational" + " if its MLCubes are not operational" + ) + return state - for k, v in data.items(): - if k not in editable_fields: - if v != getattr(self.instance, k): - raise serializers.ValidationError( - "User cannot update non editable fields" - ) - if ( - "state" in data - and data["state"] == "OPERATION" - and self.instance.state == "DEVELOPMENT" - ): - # TODO: check if there is an approved aggregator other wise raise - # and at least one approved dataset?? - pass + def validate(self, data): + event = self.instance.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot update an experiment with ongoing event" + ) + if self.instance.state == "OPERATION": + editable_fields = [ + "is_valid", + "user_metadata", + "approval_status", + "demo_dataset_tarball_url", + ] + for k, v in data.items(): + if k not in editable_fields: + if v != getattr(self.instance, k): + raise serializers.ValidationError( + "User cannot update non editable fields in Operation mode" + ) return data diff --git a/server/training/urls.py b/server/training/urls.py index ff298550b..ac3ad96cb 100644 --- a/server/training/urls.py +++ b/server/training/urls.py @@ -1,5 +1,6 @@ from django.urls import path from . 
import views +import trainingevent.views as pviews app_name = "training" @@ -7,5 +8,8 @@ path("", views.TrainingExperimentList.as_view()), path("/", views.TrainingExperimentDetail.as_view()), path("/datasets/", views.TrainingDatasetList.as_view()), - path("/aggregator/", views.GetAggregator.as_view()), + path("/aggregator/", views.TrainingAggregator.as_view()), + path("/ca/", views.TrainingCA.as_view()), + path("/plan/", pviews.EventDetail.as_view()), + path("plans/", pviews.EventList.as_view()), ] diff --git a/server/training/views.py b/server/training/views.py index 1e3f6c86a..dd46c10bc 100644 --- a/server/training/views.py +++ b/server/training/views.py @@ -1,19 +1,27 @@ +from aggregator.serializers import ( + AggregatorSerializer, +) +from traindataset_association.serializers import ( + TrainingExperimentListofDatasetsSerializer, +) +from ca.serializers import CASerializer from django.http import Http404 from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .models import TrainingExperiment from .serializers import ( WriteTrainingExperimentSerializer, ReadTrainingExperimentSerializer, ) -from .permissions import IsAdmin, IsExpOwner -from dataset.serializers import DatasetPublicSerializer -from aggregator.serializers import AggregatorSerializer -from drf_spectacular.utils import extend_schema -from aggregator_association.utils import latest_agg_associations -from traindataset_association.utils import latest_data_associations +from .permissions import ( + IsAdmin, + IsExpOwner, + IsAssociatedDatasetOwner, + IsAggregatorOwner, +) class TrainingExperimentList(GenericAPIView): @@ -43,6 +51,80 @@ def post(self, request, format=None): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +class TrainingAggregator(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsAssociatedDatasetOwner] + serializer_class = 
AggregatorSerializer + queryset = "" + + def get_object(self, pk): + try: + training_exp = TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + aggregator = training_exp.aggregator + if not aggregator: + raise Http404 + return aggregator + + def get(self, request, pk, format=None): + """ + Retrieve the aggregator associated with a training exp instance. + """ + aggregator = self.get_object(pk) + serializer = AggregatorSerializer(aggregator) + return Response(serializer.data) + + +class TrainingDatasetList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner] + serializer_class = TrainingExperimentListofDatasetsSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve datasets associated with a training experiment instance. + """ + training_exp = self.get_object(pk) + datasets = training_exp.traindataset_association_set.all() + datasets = self.paginate_queryset(datasets) + serializer = TrainingExperimentListofDatasetsSerializer(datasets, many=True) + return self.get_paginated_response(serializer.data) + + +class TrainingCA(GenericAPIView): + permission_classes = [ + IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + ] + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + training_exp = TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + ca = training_exp.ca + if not ca: + raise Http404 + return ca + + def get(self, request, pk, format=None): + """ + Retrieve CA associated with a training experiment instance. 
+ """ + ca = self.get_object(pk) + serializer = CASerializer(ca) + return Response(serializer.data) + + class TrainingExperimentDetail(GenericAPIView): serializer_class = ReadTrainingExperimentSerializer queryset = "" @@ -52,6 +134,8 @@ def get_permissions(self): self.permission_classes = [IsAdmin | IsExpOwner] if "approval_status" in self.request.data: self.permission_classes = [IsAdmin] + elif self.request.method == "DELETE": + self.permission_classes = [IsAdmin] return super(self.__class__, self).get_permissions() def get_object(self, pk): @@ -81,37 +165,10 @@ def put(self, request, pk, format=None): return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - -class TrainingDatasetList(GenericAPIView): - serializer_class = DatasetPublicSerializer - queryset = "" - - def get(self, request, pk, format=None): + def delete(self, request, pk, format=None): """ - Retrieve datasets associated with a training_exp instance. + Delete a training experiment instance. """ - experiment_datasets = latest_data_associations(pk) - experiment_datasets = experiment_datasets.filter(approval_status="APPROVED") - datasets = [exp_dset.dataset for exp_dset in experiment_datasets] - datasets = self.paginate_queryset(datasets) - serializer = DatasetPublicSerializer(datasets, many=True) - return self.get_paginated_response(serializer.data) - - -class GetAggregator(GenericAPIView): - serializer_class = AggregatorSerializer - queryset = "" - - def get(self, request, pk, format=None): - """ - Retrieve aggregator associated with a training exp instance. 
- """ - experiment_aggregators = latest_agg_associations(pk) - experiment_aggregators = experiment_aggregators.filter( - approval_status="APPROVED" - ) - aggregators = [exp_agg.aggregator for exp_agg in experiment_aggregators] - if aggregators: - serializer = AggregatorSerializer(aggregators[0]) - return Response(serializer.data) - return Response({}, status=status.HTTP_400_BAD_REQUEST) + training_exp = self.get_object(pk) + training_exp.delete() + return Response(status=status.HTTP_204_NO_CONTENT) From acc67c1406b717483b938b44570d4de47192315b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 23 Apr 2024 00:10:50 +0200 Subject: [PATCH 037/242] traindataset association updates --- server/dataset/urls.py | 3 +- .../migrations/0001_initial.py | 36 ----- .../0002_experimentdataset_training_exp.py | 22 --- server/traindataset_association/models.py | 4 +- .../traindataset_association/serializers.py | 129 ++++++++---------- server/traindataset_association/utils.py | 14 -- server/traindataset_association/views.py | 33 ++++- 7 files changed, 92 insertions(+), 149 deletions(-) delete mode 100644 server/traindataset_association/migrations/0001_initial.py delete mode 100644 server/traindataset_association/migrations/0002_experimentdataset_training_exp.py delete mode 100644 server/traindataset_association/utils.py diff --git a/server/dataset/urls.py b/server/dataset/urls.py index ff07b54dd..a4186fa64 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -11,7 +11,8 @@ path("benchmarks/", bviews.BenchmarkDatasetList.as_view()), path("/benchmarks//", bviews.DatasetApproval.as_view()), # path("/benchmarks/", bviews.DatasetBenchmarksList.as_view()), - # NOTE: when activating this endpoint later, check permissions and write tests + # path("/training_experiments/", tviews.DatasetExperimentList.as_view()), + # NOTE: when activating those two endpoints later, check permissions and write tests path("training_experiments/", tviews.ExperimentDatasetList.as_view()), 
path("/training_experiments//", tviews.DatasetApproval.as_view()), ] diff --git a/server/traindataset_association/migrations/0001_initial.py b/server/traindataset_association/migrations/0001_initial.py deleted file mode 100644 index 23e4fc840..000000000 --- a/server/traindataset_association/migrations/0001_initial.py +++ /dev/null @@ -1,36 +0,0 @@ -# Generated by Django 3.2.20 on 2023-09-29 01:02 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('dataset', '0001_initial'), - ] - - operations = [ - migrations.CreateModel( - name='ExperimentDataset', - fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('certificate', models.TextField(blank=True)), - ('signing_request', models.TextField()), - ('metadata', models.JSONField(default=dict)), - ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 'APPROVED'), ('REJECTED', 'REJECTED')], default='PENDING', max_length=100)), - ('approved_at', models.DateTimeField(blank=True, null=True)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('modified_at', models.DateTimeField(auto_now=True)), - ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='dataset.dataset')), - ('initiated_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), - ], - options={ - 'ordering': ['created_at'], - }, - ), - ] diff --git a/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py b/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py deleted file mode 100644 index 807ed44ae..000000000 --- a/server/traindataset_association/migrations/0002_experimentdataset_training_exp.py +++ /dev/null @@ -1,22 +0,0 @@ -# Generated 
by Django 3.2.20 on 2023-09-29 01:02 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ('traindataset_association', '0001_initial'), - ('training', '0001_initial'), - ] - - operations = [ - migrations.AddField( - model_name='experimentdataset', - name='training_exp', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='training.trainingexperiment'), - ), - ] diff --git a/server/traindataset_association/models.py b/server/traindataset_association/models.py index 460e8a5db..dc71107ca 100644 --- a/server/traindataset_association/models.py +++ b/server/traindataset_association/models.py @@ -10,8 +10,6 @@ class ExperimentDataset(models.Model): ("APPROVED", "APPROVED"), ("REJECTED", "REJECTED"), ) - certificate = models.TextField(blank=True) - signing_request = models.TextField() dataset = models.ForeignKey("dataset.Dataset", on_delete=models.PROTECT) training_exp = models.ForeignKey( "training.TrainingExperiment", on_delete=models.CASCADE @@ -26,4 +24,4 @@ class ExperimentDataset(models.Model): modified_at = models.DateTimeField(auto_now=True) class Meta: - ordering = ["created_at"] + ordering = ["modified_at"] diff --git a/server/traindataset_association/serializers.py b/server/traindataset_association/serializers.py index ffc74414d..3f2d8d88c 100644 --- a/server/traindataset_association/serializers.py +++ b/server/traindataset_association/serializers.py @@ -4,72 +4,61 @@ from dataset.models import Dataset from .models import ExperimentDataset -from signing.interface import verify_dataset_csr, sign_csr +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) class ExperimentDatasetListSerializer(serializers.ModelSerializer): class Meta: model = ExperimentDataset - read_only_fields = ["initiated_by", "approved_at", "certificate"] + read_only_fields = ["initiated_by", 
"approved_at"] fields = "__all__" def validate(self, data): - exp_id = self.context["request"].data.get("training_exp") + tid = self.context["request"].data.get("training_exp") dataset = self.context["request"].data.get("dataset") - approval_status = self.context["request"].data.get("approval_status") - csr = self.context["request"].data.get("signing_request") + approval_status = self.context["request"].data.get("approval_status", "PENDING") - training_exp = TrainingExperiment.objects.get(pk=exp_id) - training_exp_state = training_exp.state + training_exp = TrainingExperiment.objects.get(pk=tid) - if training_exp_state != "DEVELOPMENT": - raise serializers.ValidationError( - "Dataset Association requests can be made only " - "on a development training experiment" - ) + # training_exp approval status training_exp_approval_status = training_exp.approval_status if training_exp_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved training experiment" ) - dataset_object = Dataset.objects.get(pk=dataset) - dataset_state = dataset_object.state + + # training_exp event status + event = training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "The training experiment does not currently accept associations" + ) + + # dataset state + dataset_obj = Dataset.objects.get(pk=dataset) + dataset_state = dataset_obj.state if dataset_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational dataset" ) - last_experiment_dataset = ( - ExperimentDataset.objects.filter( - training_exp__id=exp_id, dataset__id=dataset + + # dataset prep mlcube + if dataset_obj.data_preparation_mlcube != training_exp.data_preparation_mlcube: + raise serializers.ValidationError( + "Dataset association request can be made only if the dataset" + " was prepared with the training experiment's data preparation MLCube" ) + + # approval status + 
last_training_expdataset = ( + ExperimentDataset.objects.filter(training_exp__id=tid, dataset__id=dataset) .order_by("-created_at") .first() ) - if not last_experiment_dataset: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_experiment_dataset.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_experiment_dataset.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) - - valid_csr, reason = verify_dataset_csr(csr, dataset_object, training_exp) - if not valid_csr: - raise serializers.ValidationError(reason) + validate_approval_status_on_creation(last_training_expdataset, approval_status) return data @@ -78,59 +67,59 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["dataset"].owner.id == validated_data["training_exp"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() - csr = validated_data["signing_request"] - certificate = sign_csr(csr, validated_data["training_exp"]) - validated_data["certificate"] = certificate return ExperimentDataset.objects.create(**validated_data) class DatasetApprovalSerializer(serializers.ModelSerializer): class Meta: model = ExperimentDataset - read_only_fields = ["initiated_by", "approved_at", "certificate"] + read_only_fields = ["initiated_by", "approved_at"] fields = [ "approval_status", "initiated_by", "approved_at", "created_at", 
"modified_at", - "certificate", ] def validate(self, data): if not self.instance: raise serializers.ValidationError("No dataset association found") + return data + + def validate_approval_status(self, cur_approval_status): last_approval_status = self.instance.approval_status - cur_approval_status = data["approval_status"] - if last_approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject only a pending request" - ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if ( - last_approval_status != cur_approval_status - and cur_approval_status == "APPROVED" - ): - if current_user.id == initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) - return data + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) + + event = self.instance.training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot approve or reject an association when the experiment is ongoing" + ) + return cur_approval_status def update(self, instance, validated_data): - instance.approval_status = validated_data["approval_status"] - if instance.approval_status != "PENDING": - instance.approved_at = timezone.now() - if instance.approval_status == "APPROVED": - csr = instance.signing_request - certificate = sign_csr(csr, self.instance.training_exp.id) - instance.certificate = certificate + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() instance.save() return instance + + +class TrainingExperimentListofDatasetsSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentDataset + fields = ["dataset", "approval_status", "created_at"] diff --git 
a/server/traindataset_association/utils.py b/server/traindataset_association/utils.py deleted file mode 100644 index 7143d351a..000000000 --- a/server/traindataset_association/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from django.db.models import OuterRef, Subquery -from .models import ExperimentDataset - - -def latest_data_associations(training_exp_id): - experiment_datasets = ExperimentDataset.objects.filter( - training_exp__id=training_exp_id - ) - latest_assocs = ( - experiment_datasets.filter(dataset=OuterRef("dataset")) - .order_by("-created_at") - .values("id")[:1] - ) - return experiment_datasets.filter(id__in=Subquery(latest_assocs)) diff --git a/server/traindataset_association/views.py b/server/traindataset_association/views.py index 0496a036e..729e298f9 100644 --- a/server/traindataset_association/views.py +++ b/server/traindataset_association/views.py @@ -3,6 +3,7 @@ from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status +from drf_spectacular.utils import extend_schema from .permissions import IsAdmin, IsDatasetOwner, IsExpOwner from .serializers import ( @@ -29,11 +30,37 @@ def post(self, request, format=None): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +class DatasetExperimentList(GenericAPIView): + serializer_class = ExperimentDatasetListSerializer + queryset = "" + + def get_object(self, pk): + try: + return ExperimentDataset.objects.filter(dataset__id=pk) + except ExperimentDataset.DoesNotExist: + raise Http404 + + @extend_schema(operation_id="datasets_experiments_retrieve_all") + def get(self, request, pk, format=None): + """ + Retrieve all experiments associated with a dataset + """ + training_expdataset = self.get_object(pk) + training_expdataset = self.paginate_queryset(training_expdataset) + serializer = ExperimentDatasetListSerializer(training_expdataset, many=True) + return self.get_paginated_response(serializer.data) + + class 
DatasetApproval(GenericAPIView): - permission_classes = [IsAdmin | IsExpOwner | IsDatasetOwner] serializer_class = DatasetApprovalSerializer queryset = "" + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsDatasetOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + def get_object(self, dataset_id, training_exp_id): try: return ExperimentDataset.objects.filter( @@ -46,8 +73,8 @@ def get(self, request, pk, tid, format=None): """ Retrieve approval status of training_exp dataset associations """ - training_expdataset = self.get_object(pk, tid).order_by("-created_at").first() - serializer = DatasetApprovalSerializer(training_expdataset) + training_expdataset = self.get_object(pk, tid) + serializer = DatasetApprovalSerializer(training_expdataset, many=True) return Response(serializer.data) def put(self, request, pk, tid, format=None): From c122539a7612b4d46191ce07f10027471f4ff56d Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 23 Apr 2024 03:03:47 +0200 Subject: [PATCH 038/242] add aggregator_association --- server/aggregator/urls.py | 3 + .../migrations/0001_initial.py | 36 ----- .../0002_experimentaggregator_training_exp.py | 22 --- server/aggregator_association/models.py | 2 - server/aggregator_association/serializers.py | 140 +++++++----------- server/aggregator_association/utils.py | 14 -- server/aggregator_association/views.py | 9 +- 7 files changed, 62 insertions(+), 164 deletions(-) delete mode 100644 server/aggregator_association/migrations/0001_initial.py delete mode 100644 server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py delete mode 100644 server/aggregator_association/utils.py diff --git a/server/aggregator/urls.py b/server/aggregator/urls.py index 08ef74148..8641718c7 100644 --- a/server/aggregator/urls.py +++ b/server/aggregator/urls.py @@ -1,9 +1,12 @@ from django.urls import path from . 
import views +import aggregator_association.views as tviews app_name = "aggregator" urlpatterns = [ path("", views.AggregatorList.as_view()), path("/", views.AggregatorDetail.as_view()), + path("training/", tviews.ExperimentAggregatorList.as_view()), + path("/training//", tviews.AggregatorApproval.as_view()), ] diff --git a/server/aggregator_association/migrations/0001_initial.py b/server/aggregator_association/migrations/0001_initial.py deleted file mode 100644 index f6f0825e5..000000000 --- a/server/aggregator_association/migrations/0001_initial.py +++ /dev/null @@ -1,36 +0,0 @@ -# Generated by Django 3.2.20 on 2023-09-29 01:02 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('aggregator', '0001_initial'), - ] - - operations = [ - migrations.CreateModel( - name='ExperimentAggregator', - fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('certificate', models.TextField(blank=True)), - ('signing_request', models.TextField()), - ('metadata', models.JSONField(default=dict)), - ('approval_status', models.CharField(choices=[('PENDING', 'PENDING'), ('APPROVED', 'APPROVED'), ('REJECTED', 'REJECTED')], default='PENDING', max_length=100)), - ('approved_at', models.DateTimeField(blank=True, null=True)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('modified_at', models.DateTimeField(auto_now=True)), - ('aggregator', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='aggregator.aggregator')), - ('initiated_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL)), - ], - options={ - 'ordering': ['created_at'], - }, - ), - ] diff --git a/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py 
b/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py deleted file mode 100644 index 536303150..000000000 --- a/server/aggregator_association/migrations/0002_experimentaggregator_training_exp.py +++ /dev/null @@ -1,22 +0,0 @@ -# Generated by Django 3.2.20 on 2023-09-29 01:02 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ('aggregator_association', '0001_initial'), - ('training', '0001_initial'), - ] - - operations = [ - migrations.AddField( - model_name='experimentaggregator', - name='training_exp', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='training.trainingexperiment'), - ), - ] diff --git a/server/aggregator_association/models.py b/server/aggregator_association/models.py index 0b069ca96..3642679c1 100644 --- a/server/aggregator_association/models.py +++ b/server/aggregator_association/models.py @@ -10,8 +10,6 @@ class ExperimentAggregator(models.Model): ("APPROVED", "APPROVED"), ("REJECTED", "REJECTED"), ) - certificate = models.TextField(blank=True) - signing_request = models.TextField() aggregator = models.ForeignKey("aggregator.Aggregator", on_delete=models.PROTECT) training_exp = models.ForeignKey( "training.TrainingExperiment", on_delete=models.CASCADE diff --git a/server/aggregator_association/serializers.py b/server/aggregator_association/serializers.py index 0360f1ed5..712d7e3b7 100644 --- a/server/aggregator_association/serializers.py +++ b/server/aggregator_association/serializers.py @@ -1,85 +1,59 @@ from rest_framework import serializers from django.utils import timezone from training.models import TrainingExperiment -from aggregator.models import Aggregator from .models import ExperimentAggregator -from .utils import latest_agg_associations -from signing.interface import verify_aggregator_csr, sign_csr +from utils.associations import ( + 
validate_approval_status_on_creation, + validate_approval_status_on_update, +) class ExperimentAggregatorListSerializer(serializers.ModelSerializer): class Meta: model = ExperimentAggregator - read_only_fields = ["initiated_by", "approved_at", "certificate"] + read_only_fields = ["initiated_by", "approved_at"] fields = "__all__" def validate(self, data): - exp_id = self.context["request"].data.get("training_exp") + tid = self.context["request"].data.get("training_exp") aggregator = self.context["request"].data.get("aggregator") - approval_status = self.context["request"].data.get("approval_status") - csr = self.context["request"].data.get("signing_request") + approval_status = self.context["request"].data.get("approval_status", "PENDING") - training_exp = TrainingExperiment.objects.get(pk=exp_id) - training_exp_state = training_exp.state + training_exp = TrainingExperiment.objects.get(pk=tid) - if training_exp_state != "DEVELOPMENT": - raise serializers.ValidationError( - "Aggregator Association requests can be made only " - "on a development training experiment" - ) + # training_exp approval status training_exp_approval_status = training_exp.approval_status if training_exp_approval_status != "APPROVED": raise serializers.ValidationError( "Association requests can be made only on an approved training experiment" ) - aggregator_object = Aggregator.objects.get(pk=aggregator) + + # training_exp event status + event = training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "The training experiment does not currently accept associations" + ) + + # An already approved aggregator + exp_aggregator = training_exp.aggregator + if exp_aggregator and exp_aggregator.id != aggregator: + raise serializers.ValidationError( + "The training experiment already has an aggregator" + ) + + # approval status last_experiment_aggregator = ( ExperimentAggregator.objects.filter( - training_exp__id=exp_id, aggregator__id=aggregator + training_exp__id=tid, 
aggregator__id=aggregator ) .order_by("-created_at") .first() ) - if not last_experiment_aggregator: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_experiment_aggregator.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_experiment_aggregator.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) - - # check if there is already an approved aggregator - # TODO: concurrency problem perhaps? if a user creates simultanuously two - # already APPROVED associations - experiment_aggregators = latest_agg_associations(exp_id) - approved_experiment_aggregators = experiment_aggregators.filter( - approval_status="APPROVED" - ) - if approved_experiment_aggregators.exists(): - raise serializers.ValidationError( - "This training experiment already has an aggregator" - ) - - valid_csr, reason = verify_aggregator_csr( - csr, aggregator_object, training_exp, self.context["request"] + validate_approval_status_on_creation( + last_experiment_aggregator, approval_status ) - if not valid_csr: - raise serializers.ValidationError(reason) return data @@ -88,69 +62,59 @@ def create(self, validated_data): if approval_status != "PENDING": validated_data["approved_at"] = timezone.now() else: - if ( + same_owner = ( validated_data["aggregator"].owner.id == validated_data["training_exp"].owner.id - ): + ) + if same_owner: validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() - csr = validated_data["signing_request"] - certificate = sign_csr(csr, 
validated_data["training_exp"]) - validated_data["certificate"] = certificate return ExperimentAggregator.objects.create(**validated_data) class AggregatorApprovalSerializer(serializers.ModelSerializer): class Meta: model = ExperimentAggregator - read_only_fields = ["initiated_by", "approved_at", "certificate"] + read_only_fields = ["initiated_by", "approved_at"] fields = [ "approval_status", "initiated_by", "approved_at", "created_at", "modified_at", - "certificate", ] def validate(self, data): if not self.instance: raise serializers.ValidationError("No aggregator association found") - last_approval_status = self.instance.approval_status - cur_approval_status = data["approval_status"] - if last_approval_status != "PENDING": + # check if there is already an approved aggregator + exp_aggregator = self.instance.training_exp.aggregator + if exp_aggregator and exp_aggregator.id != self.instance.aggregator.id: raise serializers.ValidationError( - "User can approve or reject only a pending request" + "The training experiment already has an aggregator" ) + return data + + def validate_approval_status(self, cur_approval_status): + last_approval_status = self.instance.approval_status initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if ( - last_approval_status != cur_approval_status - and cur_approval_status == "APPROVED" - ): - if current_user.id == initiated_user.id: - raise serializers.ValidationError( - "Same user cannot approve the association request" - ) - - # check if there is already an approved aggregator - experiment_aggregators = latest_agg_associations(self.instance.training_exp.id) - approved_experiment_aggregators = experiment_aggregators.filter( - approval_status="APPROVED" + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user ) - if approved_experiment_aggregators.exists(): + + event = self.instance.training_exp.event + if event and not event.finished: raise 
serializers.ValidationError( - "This training experiment already has an aggregator" + "User cannot approve or reject an association when the experiment is ongoing" ) - return data + return cur_approval_status def update(self, instance, validated_data): - instance.approval_status = validated_data["approval_status"] - if instance.approval_status != "PENDING": - instance.approved_at = timezone.now() - if instance.approval_status == "APPROVED": - csr = instance.signing_request - certificate = sign_csr(csr, self.instance.training_exp.id) - instance.certificate = certificate + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() instance.save() return instance diff --git a/server/aggregator_association/utils.py b/server/aggregator_association/utils.py deleted file mode 100644 index 6a447210d..000000000 --- a/server/aggregator_association/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from django.db.models import OuterRef, Subquery -from .models import ExperimentAggregator - - -def latest_agg_associations(training_exp_id): - experiment_aggregators = ExperimentAggregator.objects.filter( - training_exp__id=training_exp_id - ) - latest_assocs = ( - experiment_aggregators.filter(aggregator=OuterRef("aggregator")) - .order_by("-created_at") - .values("id")[:1] - ) - return experiment_aggregators.filter(id__in=Subquery(latest_assocs)) diff --git a/server/aggregator_association/views.py b/server/aggregator_association/views.py index 846289e82..0ab6dca75 100644 --- a/server/aggregator_association/views.py +++ b/server/aggregator_association/views.py @@ -12,7 +12,7 @@ class ExperimentAggregatorList(GenericAPIView): - permission_classes = [IsAdmin | IsAggregatorOwner] + permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] serializer_class = ExperimentAggregatorListSerializer 
queryset = "" @@ -30,10 +30,15 @@ def post(self, request, format=None): class AggregatorApproval(GenericAPIView): - permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] serializer_class = AggregatorApprovalSerializer queryset = "" + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + def get_object(self, aggregator_id, training_exp_id): try: return ExperimentAggregator.objects.filter( From 09d49c1908e22a4a62e62d7c6486e4190a14fbf9 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 23 Apr 2024 03:04:10 +0200 Subject: [PATCH 039/242] add ca_association --- server/ca/urls.py | 3 + server/ca_association/__init__.py | 0 server/ca_association/admin.py | 3 + server/ca_association/apps.py | 6 + server/ca_association/migrations/__init__.py | 0 server/ca_association/models.py | 27 +++++ server/ca_association/permissions.py | 54 +++++++++ server/ca_association/serializers.py | 117 +++++++++++++++++++ server/ca_association/views.py | 77 ++++++++++++ server/medperf/settings.py | 5 + 10 files changed, 292 insertions(+) create mode 100644 server/ca_association/__init__.py create mode 100644 server/ca_association/admin.py create mode 100644 server/ca_association/apps.py create mode 100644 server/ca_association/migrations/__init__.py create mode 100644 server/ca_association/models.py create mode 100644 server/ca_association/permissions.py create mode 100644 server/ca_association/serializers.py create mode 100644 server/ca_association/views.py diff --git a/server/ca/urls.py b/server/ca/urls.py index 70515aaf6..45d7ce343 100644 --- a/server/ca/urls.py +++ b/server/ca/urls.py @@ -1,9 +1,12 @@ from django.urls import path from . 
import views +import ca_association.views as tviews app_name = "ca" urlpatterns = [ path("", views.CAList.as_view()), path("/", views.CADetail.as_view()), + path("training/", tviews.ExperimentCAList.as_view()), + path("/training//", tviews.CAApproval.as_view()), ] diff --git a/server/ca_association/__init__.py b/server/ca_association/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca_association/admin.py b/server/ca_association/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/server/ca_association/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/server/ca_association/apps.py b/server/ca_association/apps.py new file mode 100644 index 000000000..9fba9e4c1 --- /dev/null +++ b/server/ca_association/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CAAssociationConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ca_association" diff --git a/server/ca_association/migrations/__init__.py b/server/ca_association/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/ca_association/models.py b/server/ca_association/models.py new file mode 100644 index 000000000..f67e4a21e --- /dev/null +++ b/server/ca_association/models.py @@ -0,0 +1,27 @@ +from django.db import models +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class ExperimentCA(models.Model): + MODEL_STATUS = ( + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ) + ca = models.ForeignKey("ca.CA", on_delete=models.PROTECT) + training_exp = models.ForeignKey( + "training.TrainingExperiment", on_delete=models.CASCADE + ) + initiated_by = models.ForeignKey(User, on_delete=models.PROTECT) + metadata = models.JSONField(default=dict) + approval_status = models.CharField( + choices=MODEL_STATUS, max_length=100, default="PENDING" + ) + approved_at = 
models.DateTimeField(null=True, blank=True) + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["created_at"] diff --git a/server/ca_association/permissions.py b/server/ca_association/permissions.py new file mode 100644 index 000000000..640422df1 --- /dev/null +++ b/server/ca_association/permissions.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import BasePermission +from training.models import TrainingExperiment +from ca.models import CA + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.is_superuser + + +class IsCAOwner(BasePermission): + def get_object(self, pk): + try: + return CA.objects.get(pk=pk) + except CA.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("ca", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + ca = self.get_object(pk) + if not ca: + return False + if ca.owner.id == request.user.id: + return True + else: + return False + + +class IsExpOwner(BasePermission): + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + return None + + def has_permission(self, request, view): + if request.method == "POST": + pk = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("tid", None) + if not pk: + return False + training_exp = self.get_object(pk) + if not training_exp: + return False + if training_exp.owner.id == request.user.id: + return True + else: + return False diff --git a/server/ca_association/serializers.py b/server/ca_association/serializers.py new file mode 100644 index 000000000..9e417515a --- /dev/null +++ b/server/ca_association/serializers.py @@ -0,0 +1,117 @@ +from rest_framework import serializers +from django.utils import timezone +from training.models import TrainingExperiment +from django.conf import 
settings + +from .models import ExperimentCA +from utils.associations import ( + validate_approval_status_on_creation, + validate_approval_status_on_update, +) + + +class ExperimentCAListSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentCA + read_only_fields = ["initiated_by", "approved_at"] + fields = "__all__" + + def validate(self, data): + tid = self.context["request"].data.get("training_exp") + ca = self.context["request"].data.get("ca") + approval_status = self.context["request"].data.get("approval_status", "PENDING") + + training_exp = TrainingExperiment.objects.get(pk=tid) + + # training_exp approval status + training_exp_approval_status = training_exp.approval_status + if training_exp_approval_status != "APPROVED": + raise serializers.ValidationError( + "Association requests can be made only on an approved training experiment" + ) + + # training_exp event status + event = training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "The training experiment does not currently accept associations" + ) + + # An already approved ca + exp_ca = training_exp.ca + if exp_ca and exp_ca.id != ca: + raise serializers.ValidationError( + "The training experiment already has an ca" + ) + + # approval status + last_experiment_ca = ( + ExperimentCA.objects.filter(training_exp__id=tid, ca__id=ca) + .order_by("-created_at") + .first() + ) + validate_approval_status_on_creation(last_experiment_ca, approval_status) + + return data + + def create(self, validated_data): + approval_status = validated_data.get("approval_status", "PENDING") + if approval_status != "PENDING": + validated_data["approved_at"] = timezone.now() + else: + same_owner = ( + validated_data["ca"].owner.id == validated_data["training_exp"].owner.id + ) + is_main_ca = validated_data["ca"].name == settings.CA_NAME + if same_owner or is_main_ca: + validated_data["approval_status"] = "APPROVED" + validated_data["approved_at"] = timezone.now() + return 
ExperimentCA.objects.create(**validated_data) + + +class CAApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = ExperimentCA + read_only_fields = ["initiated_by", "approved_at"] + fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + ] + + def validate(self, data): + if not self.instance: + raise serializers.ValidationError("No ca association found") + # check if there is already an approved ca + exp_ca = self.instance.training_exp.ca + if exp_ca and exp_ca.id != self.instance.ca.id: + raise serializers.ValidationError( + "The training experiment already has an ca" + ) + return data + + def validate_approval_status(self, cur_approval_status): + last_approval_status = self.instance.approval_status + initiated_user = self.instance.initiated_by + current_user = self.context["request"].user + validate_approval_status_on_update( + last_approval_status, cur_approval_status, initiated_user, current_user + ) + + event = self.instance.training_exp.event + if event and not event.finished: + raise serializers.ValidationError( + "User cannot approve or reject an association when the experiment is ongoing" + ) + return cur_approval_status + + def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + instance.save() + return instance diff --git a/server/ca_association/views.py b/server/ca_association/views.py new file mode 100644 index 000000000..4cf05bee0 --- /dev/null +++ b/server/ca_association/views.py @@ -0,0 +1,77 @@ +from .models import ExperimentCA +from django.http import Http404 +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response +from rest_framework import status + +from .permissions import IsAdmin, IsCAOwner, 
IsExpOwner +from .serializers import ( + ExperimentCAListSerializer, + CAApprovalSerializer, +) + + +class ExperimentCAList(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner | IsCAOwner] + serializer_class = ExperimentCAListSerializer + queryset = "" + + def post(self, request, format=None): + """ + Associate a ca to a training_exp + """ + serializer = ExperimentCAListSerializer( + data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save(initiated_by=request.user) + return Response(serializer.data, status=status.HTTP_201_CREATED) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class CAApproval(GenericAPIView): + serializer_class = CAApprovalSerializer + queryset = "" + + def get_permissions(self): + self.permission_classes = [IsAdmin | IsExpOwner | IsCAOwner] + if self.request.method == "DELETE": + self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + + def get_object(self, ca_id, training_exp_id): + try: + return ExperimentCA.objects.filter( + ca__id=ca_id, training_exp__id=training_exp_id + ) + except ExperimentCA.DoesNotExist: + raise Http404 + + def get(self, request, pk, tid, format=None): + """ + Retrieve approval status of training_exp ca associations + """ + training_expca = self.get_object(pk, tid).order_by("-created_at").first() + serializer = CAApprovalSerializer(training_expca) + return Response(serializer.data) + + def put(self, request, pk, tid, format=None): + """ + Update approval status of the last training_exp ca association + """ + training_expca = self.get_object(pk, tid).order_by("-created_at").first() + serializer = CAApprovalSerializer( + training_expca, data=request.data, context={"request": request} + ) + if serializer.is_valid(): + serializer.save() + return Response(serializer.data) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk, tid, format=None): + """ + 
Delete a training_exp ca association + """ + training_expca = self.get_object(pk, tid) + training_expca.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/server/medperf/settings.py b/server/medperf/settings.py index 6117eec8f..f77dd385e 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -62,6 +62,11 @@ SUPERUSER_PASSWORD = env("SUPERUSER_PASSWORD") +CA_NAME = "MedPerf CA" +CA_ADDRESS = env("CA_ADDRESS") +CA_FINGERPRINT = env("CA_FINGERPRINT") +CA_PORT = env("CA_PORT") + ALLOWED_HOSTS = env.list("ALLOWED_HOSTS", default=[]) # TODO Change later to list of allowed domains From 7cb16601d11203faea0030dedff654f18fcb0dc9 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 23 Apr 2024 03:06:51 +0200 Subject: [PATCH 040/242] update state logic in benchmarks for consistency --- server/benchmark/serializers.py | 5 ----- server/benchmarkdataset/serializers.py | 6 ------ server/benchmarkmodel/serializers.py | 6 ------ 3 files changed, 17 deletions(-) diff --git a/server/benchmark/serializers.py b/server/benchmark/serializers.py index 11f007d37..344635ade 100644 --- a/server/benchmark/serializers.py +++ b/server/benchmark/serializers.py @@ -57,11 +57,6 @@ def validate_approval_status(self, approval_status): raise serializers.ValidationError( "User can only approve or reject a benchmark" ) - if self.instance.state == "DEVELOPMENT": - raise serializers.ValidationError( - "User cannot approve or reject when benchmark is in development stage" - ) - if approval_status == "APPROVED": if self.instance.approval_status == "REJECTED": raise serializers.ValidationError( diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py index fa7a16cdc..cbf4f6d00 100644 --- a/server/benchmarkdataset/serializers.py +++ b/server/benchmarkdataset/serializers.py @@ -21,13 +21,7 @@ def validate(self, data): dataset = self.context["request"].data.get("dataset") approval_status = 
self.context["request"].data.get("approval_status", "PENDING") - # benchmark state benchmark = Benchmark.objects.get(pk=bid) - benchmark_state = benchmark.state - if benchmark_state != "OPERATION": - raise serializers.ValidationError( - "Association requests can be made only on an operational benchmark" - ) # benchmark approval status benchmark_approval_status = benchmark.approval_status diff --git a/server/benchmarkmodel/serializers.py b/server/benchmarkmodel/serializers.py index 57cd1e2ab..9e13f612b 100644 --- a/server/benchmarkmodel/serializers.py +++ b/server/benchmarkmodel/serializers.py @@ -21,13 +21,7 @@ def validate(self, data): mlcube = self.context["request"].data.get("model_mlcube") approval_status = self.context["request"].data.get("approval_status", "PENDING") - # benchmark state benchmark = Benchmark.objects.get(pk=bid) - benchmark_state = benchmark.state - if benchmark_state != "OPERATION": - raise serializers.ValidationError( - "Association requests can be made only on an operational benchmark" - ) # benchmark approval status benchmark_approval_status = benchmark.approval_status From 0febf6b35fcdec7653416a546518df76a3448185 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 23 Apr 2024 03:36:39 +0200 Subject: [PATCH 041/242] create migrations and admin views --- server/aggregator/admin.py | 6 +- server/aggregator/migrations/0001_initial.py | 56 +++++++++++ server/aggregator_association/admin.py | 6 +- .../migrations/0001_initial.py | 65 +++++++++++++ .../migrations/0002_initial.py | 25 +++++ server/ca/admin.py | 6 +- server/ca/migrations/0001_initial.py | 48 +++++++++ server/ca/migrations/0002_createmedperfca.py | 29 ++++++ server/ca_association/admin.py | 6 +- .../ca_association/migrations/0001_initial.py | 64 ++++++++++++ .../ca_association/migrations/0002_initial.py | 25 +++++ server/medperf/settings.py | 11 +-- server/traindataset_association/admin.py | 6 +- .../migrations/0001_initial.py | 65 +++++++++++++ .../migrations/0002_initial.py | 25 
+++++ server/training/admin.py | 9 +- server/training/migrations/0001_initial.py | 97 +++++++++++++++++++ server/trainingevent/admin.py | 6 +- .../trainingevent/migrations/0001_initial.py | 46 +++++++++ 19 files changed, 586 insertions(+), 15 deletions(-) create mode 100644 server/aggregator/migrations/0001_initial.py create mode 100644 server/aggregator_association/migrations/0001_initial.py create mode 100644 server/aggregator_association/migrations/0002_initial.py create mode 100644 server/ca/migrations/0001_initial.py create mode 100644 server/ca/migrations/0002_createmedperfca.py create mode 100644 server/ca_association/migrations/0001_initial.py create mode 100644 server/ca_association/migrations/0002_initial.py create mode 100644 server/traindataset_association/migrations/0001_initial.py create mode 100644 server/traindataset_association/migrations/0002_initial.py create mode 100644 server/training/migrations/0001_initial.py create mode 100644 server/trainingevent/migrations/0001_initial.py diff --git a/server/aggregator/admin.py b/server/aggregator/admin.py index 8c38f3f3d..ef24fb8a0 100644 --- a/server/aggregator/admin.py +++ b/server/aggregator/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import Aggregator -# Register your models here. 
+ +@admin.register(Aggregator) +class AggregatorAdmin(admin.ModelAdmin): + list_display = [field.name for field in Aggregator._meta.fields] diff --git a/server/aggregator/migrations/0001_initial.py b/server/aggregator/migrations/0001_initial.py new file mode 100644 index 000000000..4b4ed8ec2 --- /dev/null +++ b/server/aggregator/migrations/0001_initial.py @@ -0,0 +1,56 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="Aggregator", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("address", models.CharField(max_length=300)), + ("port", models.IntegerField()), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "aggregation_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="aggregators", + to="mlcube.mlcube", + ), + ), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/aggregator_association/admin.py b/server/aggregator_association/admin.py index 8c38f3f3d..33bbd2b50 100644 --- a/server/aggregator_association/admin.py +++ b/server/aggregator_association/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import ExperimentAggregator -# Register your models here. 
+ +@admin.register(ExperimentAggregator) +class ExperimentAggregatorAdmin(admin.ModelAdmin): + list_display = [field.name for field in ExperimentAggregator._meta.fields] diff --git a/server/aggregator_association/migrations/0001_initial.py b/server/aggregator_association/migrations/0001_initial.py new file mode 100644 index 000000000..5308cd5fb --- /dev/null +++ b/server/aggregator_association/migrations/0001_initial.py @@ -0,0 +1,65 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("aggregator", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentAggregator", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "aggregator", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="aggregator.aggregator", + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/aggregator_association/migrations/0002_initial.py b/server/aggregator_association/migrations/0002_initial.py new file mode 100644 index 000000000..9f0f3f66c --- /dev/null +++ b/server/aggregator_association/migrations/0002_initial.py @@ 
-0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + ("aggregator_association", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentaggregator", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/ca/admin.py b/server/ca/admin.py index 8c38f3f3d..04525f7d2 100644 --- a/server/ca/admin.py +++ b/server/ca/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import CA -# Register your models here. + +@admin.register(CA) +class CAAdmin(admin.ModelAdmin): + list_display = [field.name for field in CA._meta.fields] diff --git a/server/ca/migrations/0001_initial.py b/server/ca/migrations/0001_initial.py new file mode 100644 index 000000000..c58b3d149 --- /dev/null +++ b/server/ca/migrations/0001_initial.py @@ -0,0 +1,48 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="CA", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("address", models.CharField(max_length=300)), + ("port", models.IntegerField()), + ("fingerprint", models.TextField()), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "owner", + 
models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/ca/migrations/0002_createmedperfca.py b/server/ca/migrations/0002_createmedperfca.py new file mode 100644 index 000000000..3c4a4f57f --- /dev/null +++ b/server/ca/migrations/0002_createmedperfca.py @@ -0,0 +1,29 @@ +from django.contrib.auth import get_user_model +from django.db import migrations +from django.db.backends.postgresql.schema import DatabaseSchemaEditor +from django.db.migrations.state import StateApps +from django.conf import settings +from ca.models import CA + +User = get_user_model() + + +def createmedperfca(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> None: + """ + Dynamically create the configured main CA as part of a migration + """ + admin_user = User.objects.get(username=settings.SUPERUSER_USERNAME) + CA.objects.create( + name=settings.CA_NAME, + address=settings.CA_ADDRESS, + port=settings.CA_PORT, + fingerprint=settings.CA_FINGERPRINT, + owner=admin_user, + ) + + +class Migration(migrations.Migration): + + initial = True + dependencies = [("ca", "0001_initial"), ("user", "0001_createsuperuser")] + operations = [migrations.RunPython(createmedperfca)] diff --git a/server/ca_association/admin.py b/server/ca_association/admin.py index 8c38f3f3d..3317c359c 100644 --- a/server/ca_association/admin.py +++ b/server/ca_association/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import ExperimentCA -# Register your models here. 
+ +@admin.register(ExperimentCA) +class ExperimentCAAdmin(admin.ModelAdmin): + list_display = [field.name for field in ExperimentCA._meta.fields] diff --git a/server/ca_association/migrations/0001_initial.py b/server/ca_association/migrations/0001_initial.py new file mode 100644 index 000000000..9de3ccbee --- /dev/null +++ b/server/ca_association/migrations/0001_initial.py @@ -0,0 +1,64 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("ca", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentCA", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "ca", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, to="ca.ca" + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] diff --git a/server/ca_association/migrations/0002_initial.py b/server/ca_association/migrations/0002_initial.py new file mode 100644 index 000000000..d56842a1f --- /dev/null +++ b/server/ca_association/migrations/0002_initial.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.db import migrations, models +import 
django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("ca_association", "0001_initial"), + ("training", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentca", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/medperf/settings.py b/server/medperf/settings.py index f77dd385e..53df5b5f4 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -16,8 +16,6 @@ import environ import google.auth from google.cloud import secretmanager -from key_storage.gcloud_secret_manager import GcloudSecretStorage -from key_storage.local import LocalSecretStorage # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent @@ -100,8 +98,11 @@ "result", "training", "aggregator", + "ca", "traindataset_association", "aggregator_association", + "ca_association", + "trainingevent", "rest_framework", "rest_framework.authtoken", "drf_spectacular", @@ -299,9 +300,3 @@ "JTI_CLAIM": None, # Currently expected auth tokens don't contain such a claim } TOKEN_USER_EMAIL_CLAIM = "https://medperf.org/email" - -if DEPLOY_ENV == "gcp-prod": - # TODO - KEY_STORAGE = GcloudSecretStorage("") -else: - KEY_STORAGE = LocalSecretStorage(os.path.join(BASE_DIR, "keys")) diff --git a/server/traindataset_association/admin.py b/server/traindataset_association/admin.py index 8c38f3f3d..13119fb54 100644 --- a/server/traindataset_association/admin.py +++ b/server/traindataset_association/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import ExperimentDataset -# Register your models here. 
+ +@admin.register(ExperimentDataset) +class ExperimentDatasetAdmin(admin.ModelAdmin): + list_display = [field.name for field in ExperimentDataset._meta.fields] diff --git a/server/traindataset_association/migrations/0001_initial.py b/server/traindataset_association/migrations/0001_initial.py new file mode 100644 index 000000000..ac68ba155 --- /dev/null +++ b/server/traindataset_association/migrations/0001_initial.py @@ -0,0 +1,65 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("dataset", "0004_auto_20231211_1827"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentDataset", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("metadata", models.JSONField(default=dict)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "dataset", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="dataset.dataset", + ), + ), + ( + "initiated_by", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["modified_at"], + }, + ), + ] diff --git a/server/traindataset_association/migrations/0002_initial.py b/server/traindataset_association/migrations/0002_initial.py new file mode 100644 index 000000000..ec851bd0b --- /dev/null +++ b/server/traindataset_association/migrations/0002_initial.py @@ 
-0,0 +1,25 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("traindataset_association", "0001_initial"), + ("training", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="experimentdataset", + name="training_exp", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="training.trainingexperiment", + ), + ), + ] diff --git a/server/training/admin.py b/server/training/admin.py index 8c38f3f3d..27dc93261 100644 --- a/server/training/admin.py +++ b/server/training/admin.py @@ -1,3 +1,10 @@ from django.contrib import admin -# Register your models here. +from .models import TrainingExperiment + + +class TrainingExperimentAdmin(admin.ModelAdmin): + list_display = [field.name for field in TrainingExperiment._meta.fields] + + +admin.site.register(TrainingExperiment, TrainingExperimentAdmin) diff --git a/server/training/migrations/0001_initial.py b/server/training/migrations/0001_initial.py new file mode 100644 index 000000000..c4fae0e23 --- /dev/null +++ b/server/training/migrations/0001_initial.py @@ -0,0 +1,97 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="TrainingExperiment", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=20, unique=True)), + ("description", models.CharField(blank=True, max_length=100)), + ("docs_url", models.CharField(blank=True, max_length=100)), 
+ ("demo_dataset_tarball_url", models.CharField(max_length=256)), + ("demo_dataset_tarball_hash", models.CharField(max_length=100)), + ("demo_dataset_generated_uid", models.CharField(max_length=128)), + ("metadata", models.JSONField(blank=True, default=dict, null=True)), + ( + "state", + models.CharField( + choices=[ + ("DEVELOPMENT", "DEVELOPMENT"), + ("OPERATION", "OPERATION"), + ], + default="DEVELOPMENT", + max_length=100, + ), + ), + ("is_valid", models.BooleanField(default=True)), + ( + "approval_status", + models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("APPROVED", "APPROVED"), + ("REJECTED", "REJECTED"), + ], + default="PENDING", + max_length=100, + ), + ), + ("plan", models.JSONField(blank=True, null=True)), + ( + "user_metadata", + models.JSONField(blank=True, default=dict, null=True), + ), + ("approved_at", models.DateTimeField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), + ( + "data_preparation_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="training_exp", + to="mlcube.mlcube", + ), + ), + ( + "fl_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="fl_mlcube", + to="mlcube.mlcube", + ), + ), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["modified_at"], + }, + ), + ] diff --git a/server/trainingevent/admin.py b/server/trainingevent/admin.py index 8c38f3f3d..a6e5ca6af 100644 --- a/server/trainingevent/admin.py +++ b/server/trainingevent/admin.py @@ -1,3 +1,7 @@ from django.contrib import admin +from .models import TrainingEvent -# Register your models here. 
+ +@admin.register(TrainingEvent) +class TrainingEventAdmin(admin.ModelAdmin): + list_display = [field.name for field in TrainingEvent._meta.fields] diff --git a/server/trainingevent/migrations/0001_initial.py b/server/trainingevent/migrations/0001_initial.py new file mode 100644 index 000000000..894b79121 --- /dev/null +++ b/server/trainingevent/migrations/0001_initial.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.11 on 2024-04-23 01:12 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("training", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="TrainingEvent", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("finished", models.BooleanField(default=False)), + ("participants", models.JSONField()), + ("report", models.JSONField(blank=True, null=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("finished_at", models.DateTimeField(blank=True, null=True)), + ( + "training_exp", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="events", + to="training.trainingexperiment", + ), + ), + ], + options={ + "ordering": ["created_at"], + }, + ), + ] From c479c3e508e0fa192e5d5b350adf8c7cf4b6fe6e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 26 Apr 2024 16:04:02 +0200 Subject: [PATCH 042/242] update fl mlcube interface --- examples/fl/fl/clean.sh | 4 + examples/fl/fl/mlcube/mlcube.yaml | 13 +- .../fl/fl/mlcube/workspace/parameters.yaml | 166 ------------------ .../fl/mlcube/workspace/training_config.yaml | 165 +++++++++++++++++ examples/fl/fl/project/Dockerfile | 2 +- examples/fl/fl/project/README.md | 2 +- examples/fl/fl/project/aggregator.py | 5 +- examples/fl/fl/project/collaborator.py | 5 +- examples/fl/fl/project/hooks.py | 12 +- examples/fl/fl/project/mlcube.py | 37 ++-- 
examples/fl/fl/project/plan.py | 16 ++ examples/fl/fl/project/utils.py | 21 ++- examples/fl/fl/setup_test.sh | 69 +++++--- examples/fl/fl/sync.sh | 5 +- examples/fl/fl/test.sh | 25 ++- 15 files changed, 305 insertions(+), 242 deletions(-) delete mode 100644 examples/fl/fl/mlcube/workspace/parameters.yaml create mode 100644 examples/fl/fl/mlcube/workspace/training_config.yaml create mode 100644 examples/fl/fl/project/plan.py diff --git a/examples/fl/fl/clean.sh b/examples/fl/fl/clean.sh index cc7bdc725..ce7879606 100644 --- a/examples/fl/fl/clean.sh +++ b/examples/fl/fl/clean.sh @@ -3,3 +3,7 @@ rm -rf mlcube_agg/workspace/logs rm -rf mlcube_col1/workspace/logs rm -rf mlcube_col2/workspace/logs rm -rf mlcube_col3/workspace/logs +rm -rf mlcube_agg/workspace/plan.yaml +rm -rf mlcube_col1/workspace/plan.yaml +rm -rf mlcube_col2/workspace/plan.yaml +rm -rf mlcube_col3/workspace/plan.yaml diff --git a/examples/fl/fl/mlcube/mlcube.yaml b/examples/fl/fl/mlcube/mlcube.yaml index f7a67b805..639a81118 100644 --- a/examples/fl/fl/mlcube/mlcube.yaml +++ b/examples/fl/fl/mlcube/mlcube.yaml @@ -22,8 +22,7 @@ tasks: labels_path: labels/ node_cert_folder: node_cert/ ca_cert_folder: ca_cert/ - parameters_file: parameters.yaml - network_config: network.yaml + plan_path: plan.yaml outputs: output_logs: logs/ start_aggregator: @@ -32,9 +31,15 @@ tasks: input_weights: additional_files/init_weights node_cert_folder: node_cert/ ca_cert_folder: ca_cert/ - parameters_file: parameters.yaml - network_config: network.yaml + plan_path: plan.yaml collaborators: cols.yaml outputs: output_logs: logs/ output_weights: final_weights/ + generate_plan: + parameters: + inputs: + training_config_path: training_config.yaml + aggregator_config_path: aggregator_config.yaml + outputs: + plan_path: { type: "file", default: "plan/plan.yaml" } diff --git a/examples/fl/fl/mlcube/workspace/parameters.yaml b/examples/fl/fl/mlcube/workspace/parameters.yaml deleted file mode 100644 index 04acd542c..000000000 --- 
a/examples/fl/fl/mlcube/workspace/parameters.yaml +++ /dev/null @@ -1,166 +0,0 @@ -plan: - aggregator: - settings: - best_state_path: save/classification_best.pbuf - db_store_rounds: 2 - init_state_path: save/classification_init.pbuf - last_state_path: save/classification_last.pbuf - rounds_to_train: 2 - write_logs: true - template: openfl.component.Aggregator - assigner: - settings: - task_groups: - - name: train_and_validate - percentage: 1.0 - tasks: - - aggregated_model_validation - - train - - locally_tuned_model_validation - template: openfl.component.RandomGroupedAssigner - collaborator: - settings: - db_store_rounds: 1 - delta_updates: false - opt_treatment: RESET - template: openfl.component.Collaborator - compression_pipeline: - settings: {} - template: openfl.pipelines.NoCompressionPipeline - data_loader: - settings: - feature_shape: - - 128 - - 128 - template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper - network: - settings: - cert_folder: cert - client_reconnect_interval: 5 - disable_client_auth: false - hash_salt: auto - tls: true - template: openfl.federation.Network - task_runner: - settings: - device: cpu - gandlf_config: - memory_save_mode: false # - batch_size: 16 - clip_grad: null - clip_mode: null - data_augmentation: {} - data_postprocessing: {} - data_preprocessing: - resize: - - 128 - - 128 - enable_padding: false - grid_aggregator_overlap: crop - in_memory: false - inference_mechanism: - grid_aggregator_overlap: crop - patch_overlap: 0 - learning_rate: 0.001 - loss_function: cel - medcam_enabled: false - metrics: - accuracy: - average: weighted - mdmc_average: samplewise - multi_class: true - subset_accuracy: false - threshold: 0.5 - balanced_accuracy: None - classification_accuracy: None - f1: - average: weighted - f1: - average: weighted - mdmc_average: samplewise - multi_class: true - threshold: 0.5 - modality: rad - model: - amp: false - architecture: resnet18 - base_filters: 32 - batch_norm: true - class_list: - - 0 - 
- 1 - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - dimension: 2 - final_layer: sigmoid - ignore_label_validation: None - n_channels: 3 - norm_type: batch - num_channels: 3 - save_at_every_epoch: false - type: torch - nested_training: - testing: 1 - validation: -5 - num_epochs: 2 - opt: adam - optimizer: - type: adam - output_dir: . - parallel_compute_command: "" - patch_sampler: uniform - patch_size: - - 128 - - 128 - - 1 - patience: 1 - pin_memory_dataloader: false - print_rgb_label_warning: true - q_max_length: 5 - q_num_workers: 0 - q_samples_per_volume: 1 - q_verbose: false - save_masks: false - save_output: false - save_training: false - scaling_factor: 1 - scheduler: - step_size: 0.0002 - type: triangle - track_memory_usage: false - verbose: false - version: - maximum: 0.0.19 - minimum: 0.0.19 - weighted_loss: true - train_csv: train_path_full.csv - val_csv: val_path_full.csv - template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner - tasks: - aggregated_model_validation: - function: validate - kwargs: - apply: global - metrics: - - valid_loss - - valid_accuracy - locally_tuned_model_validation: - function: validate - kwargs: - apply: local - metrics: - - valid_loss - - valid_accuracy - settings: {} - train: - function: train - kwargs: - epochs: 1 - metrics: - - loss - - train_accuracy diff --git a/examples/fl/fl/mlcube/workspace/training_config.yaml b/examples/fl/fl/mlcube/workspace/training_config.yaml new file mode 100644 index 000000000..e5ba18e21 --- /dev/null +++ b/examples/fl/fl/mlcube/workspace/training_config.yaml @@ -0,0 +1,165 @@ +aggregator: + settings: + best_state_path: save/classification_best.pbuf + db_store_rounds: 2 + init_state_path: save/classification_init.pbuf + last_state_path: save/classification_last.pbuf + rounds_to_train: 2 + write_logs: true + template: openfl.component.Aggregator +assigner: + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - 
locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner +collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator +compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline +data_loader: + settings: + feature_shape: + - 128 + - 128 + template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper +network: + settings: + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto + tls: true + template: openfl.federation.Network +task_runner: + settings: + device: cpu + gandlf_config: + memory_save_mode: false # + batch_size: 16 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + resize: + - 128 + - 128 + enable_padding: false + grid_aggregator_overlap: crop + in_memory: false + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: cel + medcam_enabled: false + metrics: + accuracy: + average: weighted + mdmc_average: samplewise + multi_class: true + subset_accuracy: false + threshold: 0.5 + balanced_accuracy: None + classification_accuracy: None + f1: + average: weighted + f1: + average: weighted + mdmc_average: samplewise + multi_class: true + threshold: 0.5 + modality: rad + model: + amp: false + architecture: resnet18 + base_filters: 32 + batch_norm: true + class_list: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + dimension: 2 + final_layer: sigmoid + ignore_label_validation: None + n_channels: 3 + norm_type: batch + num_channels: 3 + save_at_every_epoch: false + type: torch + nested_training: + testing: 1 + validation: -5 + num_epochs: 2 + opt: adam + optimizer: + type: adam + output_dir: . 
+ parallel_compute_command: "" + patch_sampler: uniform + patch_size: + - 128 + - 128 + - 1 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 5 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_masks: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + step_size: 0.0002 + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.20-dev + minimum: 0.0.20-dev + weighted_loss: true + train_csv: train_path_full.csv + val_csv: val_path_full.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner +tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_accuracy + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_accuracy + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_accuracy \ No newline at end of file diff --git a/examples/fl/fl/project/Dockerfile b/examples/fl/fl/project/Dockerfile index 90d680efc..9d7877a02 100644 --- a/examples/fl/fl/project/Dockerfile +++ b/examples/fl/fl/project/Dockerfile @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev RUN pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 && \ pip install --no-cache-dir openvino-dev==2023.0.1 && \ git clone https://github.com/mlcommons/GaNDLF.git && \ - cd GaNDLF && git checkout 64962a1d3416071299a452126baccd3163f0b2d8 && \ + cd GaNDLF && git checkout dd88b8883cb0e57a0ac615e9cb5be7416d0dada4 && \ pip install --no-cache-dir -e . 
COPY ./requirements.txt /mlcube_project/requirements.txt diff --git a/examples/fl/fl/project/README.md b/examples/fl/fl/project/README.md index 21ad970e9..1e348651b 100644 --- a/examples/fl/fl/project/README.md +++ b/examples/fl/fl/project/README.md @@ -15,8 +15,8 @@ ```bash git clone https://github.com/securefederatedai/openfl.git -git checkout e6f3f5fd4462307b2c9431184190167aa43d962f cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py index fd7e60a5d..36d8b02f1 100644 --- a/examples/fl/fl/project/aggregator.py +++ b/examples/fl/fl/project/aggregator.py @@ -17,18 +17,17 @@ def start_aggregator( input_weights, - parameters_file, node_cert_folder, ca_cert_folder, output_logs, output_weights, - network_config, + plan_path, collaborators, ): workspace_folder = os.path.join(output_logs, "workspace") create_workspace(workspace_folder) - prepare_plan(parameters_file, network_config, workspace_folder) + prepare_plan(plan_path, workspace_folder) prepare_cols_list(collaborators, workspace_folder) prepare_init_weights(input_weights, workspace_folder) fqdn = get_aggregator_fqdn(workspace_folder) diff --git a/examples/fl/fl/project/collaborator.py b/examples/fl/fl/project/collaborator.py index 0762f5b82..38c5048b6 100644 --- a/examples/fl/fl/project/collaborator.py +++ b/examples/fl/fl/project/collaborator.py @@ -13,15 +13,14 @@ def start_collaborator( data_path, labels_path, - parameters_file, node_cert_folder, ca_cert_folder, - network_config, + plan_path, output_logs, ): workspace_folder = os.path.join(output_logs, "workspace") create_workspace(workspace_folder) - prepare_plan(parameters_file, network_config, workspace_folder) + prepare_plan(plan_path, workspace_folder) cn = get_collaborator_cn() prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) 
prepare_ca_cert(ca_cert_folder, workspace_folder) diff --git a/examples/fl/fl/project/hooks.py b/examples/fl/fl/project/hooks.py index dfa4792b5..5b124e8b0 100644 --- a/examples/fl/fl/project/hooks.py +++ b/examples/fl/fl/project/hooks.py @@ -26,10 +26,9 @@ def __modify_df(df): def collaborator_pre_training_hook( data_path, labels_path, - parameters_file, node_cert_folder, ca_cert_folder, - network_config, + plan_path, output_logs, ): cn = get_collaborator_cn() @@ -66,10 +65,9 @@ def collaborator_pre_training_hook( def collaborator_post_training_hook( data_path, labels_path, - parameters_file, node_cert_folder, ca_cert_folder, - network_config, + plan_path, output_logs, ): pass @@ -77,12 +75,11 @@ def collaborator_post_training_hook( def aggregator_pre_training_hook( input_weights, - parameters_file, node_cert_folder, ca_cert_folder, output_logs, output_weights, - network_config, + plan_path, collaborators, ): pass @@ -90,12 +87,11 @@ def aggregator_pre_training_hook( def aggregator_post_training_hook( input_weights, - parameters_file, node_cert_folder, ca_cert_folder, output_logs, output_weights, - network_config, + plan_path, collaborators, ): pass diff --git a/examples/fl/fl/project/mlcube.py b/examples/fl/fl/project/mlcube.py index d88a8eb29..f3cfa640d 100644 --- a/examples/fl/fl/project/mlcube.py +++ b/examples/fl/fl/project/mlcube.py @@ -5,6 +5,7 @@ import typer from collaborator import start_collaborator from aggregator import start_aggregator +from plan import generate_plan from hooks import ( aggregator_pre_training_hook, aggregator_post_training_hook, @@ -31,38 +32,34 @@ def _teardown(output_logs): def train( data_path: str = typer.Option(..., "--data_path"), labels_path: str = typer.Option(..., "--labels_path"), - parameters_file: str = typer.Option(..., "--parameters_file"), node_cert_folder: str = typer.Option(..., "--node_cert_folder"), ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), - network_config: str = typer.Option(..., 
"--network_config"), + plan_path: str = typer.Option(..., "--plan_path"), output_logs: str = typer.Option(..., "--output_logs"), ): _setup(output_logs) collaborator_pre_training_hook( data_path=data_path, labels_path=labels_path, - parameters_file=parameters_file, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, - network_config=network_config, + plan_path=plan_path, output_logs=output_logs, ) start_collaborator( data_path=data_path, labels_path=labels_path, - parameters_file=parameters_file, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, - network_config=network_config, + plan_path=plan_path, output_logs=output_logs, ) collaborator_post_training_hook( data_path=data_path, labels_path=labels_path, - parameters_file=parameters_file, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, - network_config=network_config, + plan_path=plan_path, output_logs=output_logs, ) _teardown(output_logs) @@ -71,47 +68,55 @@ def train( @app.command("start_aggregator") def start_aggregator_( input_weights: str = typer.Option(..., "--input_weights"), - parameters_file: str = typer.Option(..., "--parameters_file"), node_cert_folder: str = typer.Option(..., "--node_cert_folder"), ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), output_logs: str = typer.Option(..., "--output_logs"), output_weights: str = typer.Option(..., "--output_weights"), - network_config: str = typer.Option(..., "--network_config"), + plan_path: str = typer.Option(..., "--plan_path"), collaborators: str = typer.Option(..., "--collaborators"), ): _setup(output_logs) aggregator_pre_training_hook( input_weights=input_weights, - parameters_file=parameters_file, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, output_logs=output_logs, output_weights=output_weights, - network_config=network_config, + plan_path=plan_path, collaborators=collaborators, ) start_aggregator( input_weights=input_weights, - parameters_file=parameters_file, 
node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, output_logs=output_logs, output_weights=output_weights, - network_config=network_config, + plan_path=plan_path, collaborators=collaborators, ) aggregator_post_training_hook( input_weights=input_weights, - parameters_file=parameters_file, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, output_logs=output_logs, output_weights=output_weights, - network_config=network_config, + plan_path=plan_path, collaborators=collaborators, ) _teardown(output_logs) +@app.command("generate_plan") +def generate_plan_( + training_config_path: str = typer.Option(..., "--training_config_path"), + aggregator_config_path: str = typer.Option(..., "--aggregator_config_path"), + plan_path: str = typer.Option(..., "--plan_path"), +): + # no _setup here since there is no writable output mounted volume. + # later if need this we think of a solution. Currently the create_plam + # logic is assumed to not write within the container. + generate_plan(training_config_path, aggregator_config_path, plan_path) + + if __name__ == "__main__": app() diff --git a/examples/fl/fl/project/plan.py b/examples/fl/fl/project/plan.py new file mode 100644 index 000000000..2feb1bf52 --- /dev/null +++ b/examples/fl/fl/project/plan.py @@ -0,0 +1,16 @@ +import yaml + + +def generate_plan(training_config_path, aggregator_config_path, plan_path): + with open(training_config_path) as f: + training_config = yaml.safe_load(f) + with open(aggregator_config_path) as f: + aggregator_config = yaml.safe_load(f) + + # TODO: key checks. Also, define what should be considered aggregator_config + # (e.g., tls=true, reconnect_interval, ...) 
+ training_config["network"]["settings"]["agg_addr"] = aggregator_config["address"] + training_config["network"]["settings"]["agg_port"] = aggregator_config["port"] + + with open(plan_path, "w") as f: + yaml.dump(training_config, f) diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index a7d2a0851..558d0f8dd 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -1,5 +1,6 @@ import yaml import os +import shutil def create_workspace(fl_workspace): @@ -36,21 +37,13 @@ def get_weights_path(fl_workspace): } -def prepare_plan(parameters_file, network_config, fl_workspace): - with open(parameters_file) as f: - params = yaml.safe_load(f) - if "plan" not in params: - raise RuntimeError("Parameters file should contain a `plan` entry") - with open(network_config) as f: - network_config_dict = yaml.safe_load(f) - plan = params["plan"] - plan["network"]["settings"].update(network_config_dict) +def prepare_plan(plan_path, fl_workspace): target_plan_folder = os.path.join(fl_workspace, "plan") # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") - with open(target_plan_file, "w") as f: - yaml.dump(plan, f) + shutil.copyfile(plan_path, target_plan_file) def prepare_cols_list(collaborators_file, fl_workspace): @@ -58,12 +51,18 @@ def prepare_cols_list(collaborators_file, fl_workspace): cols = f.read().strip().split("\n") cols = [col.strip().split(",") for col in cols] cols_dict = {} + cn_different = False for col in cols: if len(col) == 1: cols_dict[col[0]] = col[0] else: assert len(col) == 2 cols_dict[col[0]] = col[1] + if col[0] != col[1]: + cn_different = True + if not cn_different: + # quick hack to support old and new openfl versions + cols_dict = list(cols_dict.keys()) target_plan_folder = os.path.join(fl_workspace, "plan") # TODO: permissions diff --git a/examples/fl/fl/setup_test.sh b/examples/fl/fl/setup_test.sh index 
f728883b3..75a3d68f5 100644 --- a/examples/fl/fl/setup_test.sh +++ b/examples/fl/fl/setup_test.sh @@ -1,3 +1,23 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. it's OK +fi + cp -r ./mlcube ./mlcube_agg cp -r ./mlcube ./mlcube_col1 cp -r ./mlcube ./mlcube_col2 @@ -17,8 +37,8 @@ openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root. -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" # col1 -sed -i '/^commonName = /c\commonName = col1@example.com' csr.conf -sed -i '/^DNS\.1 = /c\DNS.1 = col1@example.com' csr.conf +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf cd mlcube_col1/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client @@ -29,13 +49,9 @@ cp ../../../ca/root.crt ../ca_cert/ cd ../../../ # col2 -cp mlcube_col1/workspace/node_cert/* mlcube_col2/workspace/node_cert -cp mlcube_col1/workspace/ca_cert/* mlcube_col2/workspace/ca_cert - -# col3 -sed -i '/^commonName = /c\commonName = col3@example.com' csr.conf -sed -i '/^DNS\.1 = /c\DNS.1 = col3@example.com' csr.conf -cd mlcube_col3/workspace/node_cert +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions 
v3_client openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ @@ -44,6 +60,23 @@ rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ +fi + # agg sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf @@ -56,20 +89,14 @@ rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ -# network file -echo "agg_addr: $HOSTNAME_" >>mlcube_col1/workspace/network.yaml -echo "agg_port: 50273" >>mlcube_col1/workspace/network.yaml -echo "address: $HOSTNAME_" >>mlcube_col1/workspace/network.yaml -echo "port: 50273" >>mlcube_col1/workspace/network.yaml - -cp mlcube_col1/workspace/network.yaml mlcube_col2/workspace/network.yaml -cp mlcube_col1/workspace/network.yaml mlcube_agg/workspace/network.yaml -cp mlcube_col1/workspace/network.yaml mlcube_col3/workspace/network.yaml +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols file -echo "abc,col1@example.com" >>mlcube_agg/workspace/cols.yaml -echo "defg,col1@example.com" >>mlcube_agg/workspace/cols.yaml -echo "hij,col3@example.com" >>mlcube_agg/workspace/cols.yaml +echo "$COL1_LABEL,$COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo 
"$COL2_LABEL,$COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL,$COL3_CN" >>mlcube_agg/workspace/cols.yaml # data download cd mlcube_col1/workspace/ diff --git a/examples/fl/fl/sync.sh b/examples/fl/fl/sync.sh index d454a9c2f..a5375ce54 100644 --- a/examples/fl/fl/sync.sh +++ b/examples/fl/fl/sync.sh @@ -1,7 +1,4 @@ -cp mlcube/workspace/parameters.yaml mlcube_agg/workspace/parameters.yaml -cp mlcube/workspace/parameters.yaml mlcube_col1/workspace/parameters.yaml -cp mlcube/workspace/parameters.yaml mlcube_col2/workspace/parameters.yaml -cp mlcube/workspace/parameters.yaml mlcube_col3/workspace/parameters.yaml +cp mlcube/workspace/training_config.yaml mlcube_agg/workspace/training_config.yaml cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh index 1b751ce37..c6a77ffed 100644 --- a/examples/fl/fl/test.sh +++ b/examples/fl/fl/test.sh @@ -1,4 +1,21 @@ -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273; bash" -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col1 --task train -e COLLABORATOR_CN=abc; bash" -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col2 --task train -e COLLABORATOR_CN=defg; bash" -gnome-terminal -- bash -c "medperf mlcube run --mlcube ./mlcube_col3 --task train -e COLLABORATOR_CN=hij; bash" +# generate plan and copy it to each node +medperf mlcube run --mlcube ./mlcube_agg --task generate_plan +mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +rm -r ./mlcube_agg/workspace/plan +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace + +# Run nodes +AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e 
COLLABORATOR_CN=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e COLLABORATOR_CN=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e COLLABORATOR_CN=col3@example.com" + +gnome-terminal -- bash -c "$AGG; bash" +gnome-terminal -- bash -c "$COL1; bash" +gnome-terminal -- bash -c "$COL2; bash" +gnome-terminal -- bash -c "$COL3; bash" + +# docker run --env COLLABORATOR_CN=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 From 48ccd60e008a010fb024c67f1ca6afba82744b7b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Sat, 27 Apr 2024 07:21:32 +0200 Subject: [PATCH 043/242] refactor entities --- cli/cli_tests.sh | 4 +- cli/medperf/commands/benchmark/benchmark.py | 16 +- .../compatibility_test/compatibility_test.py | 8 +- .../commands/compatibility_test/utils.py | 16 +- cli/medperf/commands/dataset/dataset.py | 16 +- cli/medperf/commands/list.py | 14 +- cli/medperf/commands/mlcube/mlcube.py | 16 +- cli/medperf/commands/result/create.py | 5 +- cli/medperf/commands/result/result.py | 18 +- cli/medperf/commands/view.py | 12 +- cli/medperf/entities/benchmark.py | 212 ++---------------- 
cli/medperf/entities/cube.py | 161 +++---------- cli/medperf/entities/dataset.py | 187 +++------------ cli/medperf/entities/interface.py | 206 ++++++++++++++--- cli/medperf/entities/report.py | 92 ++------ cli/medperf/entities/result.py | 182 ++------------- cli/medperf/entities/schemas.py | 4 +- .../tests/commands/result/test_create.py | 3 + cli/medperf/tests/commands/test_list.py | 8 +- cli/medperf/tests/commands/test_view.py | 211 +++++++---------- cli/medperf/tests/entities/test_benchmark.py | 5 +- cli/medperf/tests/entities/test_cube.py | 13 +- cli/medperf/tests/entities/test_entity.py | 73 +++--- cli/medperf/tests/entities/utils.py | 85 ++++--- 24 files changed, 560 insertions(+), 1007 deletions(-) diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index ac6137b65..68764618a 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Printing MedPerf version" @@ -186,7 +185,7 @@ echo "Running data submission step" echo "=====================================" medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name="dataset_a" --description="mock dataset a" --location="mock location a" -y checkFailed "Data submission step failed" -DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1) +DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## echo "\n" @@ -212,7 +211,6 @@ DSET_A_GENUID=$(medperf dataset view $DSET_A_UID | grep generated_uid | cut -d " echo "\n" - ########################################################## echo "=====================================" echo "Moving storage to some other location" diff --git 
a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index f02d67cb4..35d719b0d 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local benchmarks"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered benchmarks" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), ): - """List benchmarks stored locally and remotely from the user""" + """List benchmarks""" EntityList.run( Benchmark, fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -162,10 +164,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -180,4 +182,4 @@ def view( ), ): """Displays the information of one or more benchmarks""" - EntityView.run(entity_id, Benchmark, format, local, mine, output) + EntityView.run(entity_id, Benchmark, format, unregistered, mine, output) diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index a3b25ac78..0bd4a4695 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -95,7 +95,11 @@ def run( @clean_except def list(): """List previously executed tests reports.""" - EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"]) + EntityList.run( + TestReport, + fields=["UID", "Data 
Source", "Model", "Evaluator"], + unregistered=True, + ) @app.command("view") @@ -116,4 +120,4 @@ def view( ), ): """Displays the information of one or more test reports""" - EntityView.run(entity_id, TestReport, format, output=output) + EntityView.run(entity_id, TestReport, format, unregistered=True, output=output) diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index a12ac5ea2..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -138,23 +138,23 @@ def create_test_dataset( # TODO: existing dataset could make problems # make some changes since this is a test dataset config.tmp_paths.remove(data_creation.dataset.path) - data_creation.dataset.write() if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = data_creation.dataset + old_generated_uid = dataset.generated_uid + old_path = dataset.path # prepare/check dataset DataPreparation.run(dataset.generated_uid) # update dataset generated_uid - old_path = dataset.path - generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) - dataset.generated_uid = generated_uid - dataset.write() - if dataset.input_data_hash != dataset.generated_uid: + new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) + if new_generated_uid != old_generated_uid: # move to a correct location if it underwent preparation - new_path = old_path.replace(dataset.input_data_hash, generated_uid) + new_path = old_path.replace(old_generated_uid, new_generated_uid) remove_path(new_path) os.rename(old_path, new_path) + dataset.generated_uid = new_generated_uid + dataset.write() - return generated_uid + return new_generated_uid diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index a27e36814..fc18022ac 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ 
-17,17 +17,19 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local datasets"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered datasets" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), ): - """List datasets stored locally and remotely from the user""" + """List datasets""" EntityList.run( Dataset, fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"], - local_only=local, + unregistered=unregistered, mine_only=mine, mlcube=mlcube, ) @@ -149,8 +151,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local datasets if dataset ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered datasets if dataset ID is not provided", ), mine: bool = typer.Option( False, @@ -165,4 +169,4 @@ def view( ), ): """Displays the information of one or more datasets""" - EntityView.run(entity_id, Dataset, format, local, mine, output) + EntityView.run(entity_id, Dataset, format, unregistered, mine, output) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..b5d6226a4 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -10,27 +10,29 @@ class EntityList: def run( entity_class, fields, - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, **kwargs, ): """Lists all local datasets Args: - local_only (bool, optional): Display all local results. Defaults to False. + unregistered (bool, optional): Display only local unregistered results. Defaults to False. mine_only (bool, optional): Display all current-user results. Defaults to False. 
kwargs (dict): Additional parameters for filtering entity lists. """ - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList( + entity_class, fields, unregistered, mine_only, **kwargs + ) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__(self, entity_class, fields, unregistered, mine_only, **kwargs): self.entity_class = entity_class self.fields = fields - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.filters = kwargs self.data = [] @@ -40,7 +42,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.display_dict() for entity in entities] diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 7b53ce940..bad358f8e 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -41,14 +41,16 @@ def run( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local mlcubes"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered mlcubes" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), ): - """List mlcubes stored locally and remotely from the user""" + """List mlcubes""" EntityList.run( Cube, fields=["UID", "Name", "State", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -173,8 +175,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local mlcubes if mlcube ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered mlcubes if mlcube ID is not provided", ), mine: bool = typer.Option( False, @@ -189,4 +193,4 @@ def view( ), ): """Displays the information of one or more mlcubes""" - EntityView.run(entity_id, Cube, format, local, mine, output) + EntityView.run(entity_id, Cube, format, unregistered, mine, output) diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 42f97d990..760dddc94 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -1,5 +1,6 @@ import os from typing import List, Optional +from medperf.account_management.account_management import get_medperf_user_data from medperf.commands.execution import Execution from medperf.entities.result import Result from tabulate import tabulate @@ -143,7 +144,9 @@ def __validate_models(self, benchmark_models): raise InvalidArgumentError(msg) def load_cached_results(self): - results = Result.all() + user_id = get_medperf_user_data()["id"] + results = Result.all(filters={"owner": user_id}) + results += Result.all(unregistered=True) benchmark_dset_results = [ result for result in results diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 6fbb3b08a..40b65c52e 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -62,17 +62,19 @@ def submit( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local results"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered results" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given 
benchmark" ), ): - """List results stored locally and remotely from the user""" + """List results""" EntityList.run( Result, fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, benchmark=benchmark, ) @@ -88,8 +90,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local results if result ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered results if result ID is not provided", ), mine: bool = typer.Option( False, @@ -107,4 +111,6 @@ def view( ), ): """Displays the information of one or more results""" - EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark) + EntityView.run( + entity_id, Result, format, unregistered, mine, output, benchmark=benchmark + ) diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index b4c242f0a..8c2a4179f 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -14,7 +14,7 @@ def run( entity_id: Union[int, str], entity_class: Entity, format: str = "yaml", - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, output: str = None, **kwargs, @@ -24,14 +24,14 @@ def run( Args: entity_id (Union[int, str]): Entity identifies entity_class (Entity): Entity type - local_only (bool, optional): Display all local entities. Defaults to False. + unregistered (bool, optional): Display only local unregistered entities. Defaults to False. mine_only (bool, optional): Display all current-user entities. Defaults to False. format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml. output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed. kwargs (dict): Additional parameters for filtering entity lists. 
""" entity_view = EntityView( - entity_id, entity_class, format, local_only, mine_only, output, **kwargs + entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ) entity_view.validate() entity_view.prepare() @@ -41,12 +41,12 @@ def run( entity_view.store() def __init__( - self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs + self, entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ): self.entity_id = entity_id self.entity_class = entity_class self.format = format - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.output = output self.filters = kwargs @@ -65,7 +65,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.todict() for entity in entities] diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 849ea3fcd..1d33efa95 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,18 +1,13 @@ -import os -from medperf.exceptions import MedperfException -import yaml -import logging -from typing import List, Optional, Union +from typing import List, Optional from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -35,6 +30,26 @@ class Benchmark(Entity, 
Uploadable, MedperfSchema, ApprovableSchema, DeployableS user_metadata: dict = {} is_active: bool = True + @staticmethod + def get_type(): + return "benchmark" + + @staticmethod + def get_storage_path(): + return config.benchmarks_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_benchmark + + @staticmethod + def get_metadata_filename(): + return config.benchmarks_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_benchmark + def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -44,53 +59,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" - path = config.benchmarks_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: - """Gets and creates instances of all retrievable benchmarks - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Benchmark]: a list of Benchmark instances. 
- """ - logging.info("Retrieving all benchmarks") - benchmarks = [] - - if not local_only: - benchmarks = cls.__remote_all(filters=filters) - - remote_uids = set([bmk.id for bmk in benchmarks]) - - local_benchmarks = cls.__local_all() - - benchmarks += [bmk for bmk in local_benchmarks if bmk.id not in remote_uids] - - return benchmarks @classmethod - def __remote_all(cls, filters: dict) -> List["Benchmark"]: - benchmarks = [] - try: - comms_fn = cls.__remote_prefilter(filters) - bmks_meta = comms_fn() - benchmarks = [cls(**meta) for meta in bmks_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all benchmarks from the server" - logging.warning(msg) - - return benchmarks - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,104 +75,6 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_benchmarks return comms_fn - @classmethod - def __local_all(cls) -> List["Benchmark"]: - benchmarks = [] - bmks_storage = config.benchmarks_folder - try: - uids = next(os.walk(bmks_storage))[1] - except StopIteration: - msg = "Couldn't iterate over benchmarks directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - benchmark = cls(**meta) - benchmarks.append(benchmark) - - return benchmarks - - @classmethod - def get( - cls, benchmark_uid: Union[str, int], local_only: bool = False - ) -> "Benchmark": - """Retrieves and creates a Benchmark instance from the server. - If benchmark already exists in the platform then retrieve that - version. - - Args: - benchmark_uid (str): UID of the benchmark. - comms (Comms): Instance of a communication interface. - - Returns: - Benchmark: a Benchmark instance with the retrieved data. 
- """ - - if not str(benchmark_uid).isdigit() or local_only: - return cls.__local_get(benchmark_uid) - - try: - return cls.__remote_get(benchmark_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Benchmark {benchmark_uid} from comms failed") - logging.info(f"Looking for benchmark {benchmark_uid} locally") - return cls.__local_get(benchmark_uid) - - @classmethod - def __remote_get(cls, benchmark_uid: int) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") - benchmark_dict = config.comms.get_benchmark(benchmark_uid) - benchmark = cls(**benchmark_dict) - benchmark.write() - return benchmark - - @classmethod - def __local_get(cls, benchmark_uid: Union[str, int]) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} locally") - benchmark_dict = cls.__get_local_dict(benchmark_uid) - benchmark = cls(**benchmark_dict) - return benchmark - - @classmethod - def __get_local_dict(cls, benchmark_uid) -> dict: - """Retrieves a local benchmark information - - Args: - benchmark_uid (str): uid of the local benchmark - - Returns: - dict: information of the benchmark - """ - logging.info(f"Retrieving benchmark {benchmark_uid} from local storage") - storage = config.benchmarks_folder - bmk_storage = os.path.join(storage, str(benchmark_uid)) - bmk_file = os.path.join(bmk_storage, config.benchmarks_filename) - if not os.path.exists(bmk_file): - raise InvalidArgumentError("No benchmark with the given uid could be found") - with open(bmk_file, "r") as f: - data = yaml.safe_load(f) - - return data - @classmethod def get_models_uids(cls, benchmark_uid: int) -> List[int]: """Retrieves the list of models associated to the benchmark @@ -221,43 +94,6 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: ] return models_uids - def todict(self) -> dict: - """Dictionary representation of the benchmark instance - - Returns: - dict: Dictionary containing benchmark information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the benchmark into disk - - Args: - filename (str, optional): name of the file. Defaults to config.benchmarks_filename. 
- - Returns: - str: path to the created benchmark file - """ - data = self.todict() - bmk_file = os.path.join(self.path, config.benchmarks_filename) - if not os.path.exists(bmk_file): - os.makedirs(self.path, exist_ok=True) - with open(bmk_file, "w") as f: - yaml.dump(data, f) - return bmk_file - - def upload(self): - """Uploads a benchmark to the server - - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test benchmarks.") - body = self.todict() - updated_body = config.comms.upload_benchmark(body) - return updated_body - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index c76a50c09..b327417e2 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,7 +1,7 @@ import os import yaml import logging -from typing import List, Dict, Optional, Union +from typing import Dict, Optional, Union from pydantic import Field from pathlib import Path @@ -12,21 +12,15 @@ generate_tmp_path, spawn_and_kill, ) -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - ExecutionError, - InvalidEntityError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, MedperfSchema, DeployableSchema): """ Class representing an MLCube Container @@ -48,6 +42,26 @@ class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): metadata: dict = {} user_metadata: dict = {} + @staticmethod + def get_type(): + return 
"cube" + + @staticmethod + def get_storage_path(): + return config.cubes_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_cube_metadata + + @staticmethod + def get_metadata_filename(): + return config.cube_metadata_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_mlcube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -57,59 +71,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = self.name - path = config.cubes_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - # NOTE: maybe have these as @property, to have the same entity reusable - # before and after submission - self.path = path - self.cube_path = os.path.join(path, config.cube_filename) + self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(path, config.params_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: - """Class method for retrieving all retrievable MLCubes - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Cube]: List containing all cubes - """ - logging.info("Retrieving all cubes") - cubes = [] - if not local_only: - cubes = cls.__remote_all(filters=filters) - - remote_uids = set([cube.id for cube in cubes]) - - local_cubes = cls.__local_all() - - cubes += [cube for cube in local_cubes if cube.id not in remote_uids] - - return cubes - - @classmethod - def __remote_all(cls, filters: dict) -> List["Cube"]: - cubes = [] - - try: - comms_fn = cls.__remote_prefilter(filters) - cubes_meta = comms_fn() - cubes = [cls(**meta) for meta in cubes_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all cubes from the server" - logging.warning(msg) - - return cubes + self.params_path = os.path.join(self.path, config.params_filename) @classmethod - def __remote_prefilter(cls, filters: dict): + def _Entity__remote_prefilter(cls, filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: @@ -124,25 +92,6 @@ def __remote_prefilter(cls, filters: dict): return comms_fn - @classmethod - def __local_all(cls) -> List["Cube"]: - cubes = [] - cubes_folder = config.cubes_folder - try: - uids = next(os.walk(cubes_folder))[1] - logging.debug(f"Local cubes found: {uids}") - except StopIteration: - msg = "Couldn't iterate over cubes directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - cube = cls(**meta) - cubes.append(cube) - - return cubes - @classmethod def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": """Retrieves and creates a Cube instance from the comms. If cube already exists @@ -155,36 +104,12 @@ def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": Cube : a Cube instance with the retrieved data. 
""" - if not str(cube_uid).isdigit() or local_only: - cube = cls.__local_get(cube_uid) - else: - try: - cube = cls.__remote_get(cube_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting MLCube {cube_uid} from comms failed") - logging.info(f"Retrieving MLCube {cube_uid} from local storage") - cube = cls.__local_get(cube_uid) - + cube = super().get(cube_uid, local_only) if not cube.is_valid: raise InvalidEntityError("The requested MLCube is marked as INVALID.") cube.download_config_files() return cube - @classmethod - def __remote_get(cls, cube_uid: int) -> "Cube": - logging.debug(f"Retrieving mlcube {cube_uid} remotely") - meta = config.comms.get_cube_metadata(cube_uid) - cube = cls(**meta) - cube.write() - return cube - - @classmethod - def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": - logging.debug(f"Retrieving cube {cube_uid} locally") - local_meta = cls.__get_local_dict(cube_uid) - cube = cls(**local_meta) - return cube - def download_mlcube(self): url = self.git_mlcube_url path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) @@ -449,36 +374,6 @@ def get_config(self, identifier): return cube - def todict(self) -> Dict: - return self.extended_dict() - - def write(self): - cube_loc = str(Path(self.cube_path).parent) - meta_file = os.path.join(cube_loc, config.cube_metadata_filename) - os.makedirs(cube_loc, exist_ok=True) - with open(meta_file, "w") as f: - yaml.dump(self.todict(), f) - return meta_file - - def upload(self): - if self.for_test: - raise InvalidArgumentError("Cannot upload test mlcubes.") - cube_dict = self.todict() - updated_cube_dict = config.comms.upload_mlcube(cube_dict) - return updated_cube_dict - - @classmethod - def __get_local_dict(cls, uid): - cubes_folder = config.cubes_folder - meta_file = os.path.join(cubes_folder, str(uid), config.cube_metadata_filename) - if not os.path.exists(meta_file): - raise InvalidArgumentError( - "The requested mlcube information could not be found locally" - ) - 
with open(meta_file, "r") as f: - meta = yaml.safe_load(f) - return meta - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 4c210431f..f50e8d680 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,22 +1,17 @@ import os import yaml -import logging from pydantic import Field, validator -from typing import List, Optional, Union +from typing import Optional, Union from medperf.utils import remove_path -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) + import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, MedperfSchema, DeployableSchema): """ Class representing a Dataset @@ -37,6 +32,26 @@ class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): report: dict = {} submitted_as_prepared: bool + @staticmethod + def get_type(): + return "dataset" + + @staticmethod + def get_storage_path(): + return config.datasets_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_dataset + + @staticmethod + def get_metadata_filename(): + return config.reg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_dataset + @validator("data_preparation_mlcube", pre=True, always=True) def check_data_preparation_mlcube(cls, v, *, values, **kwargs): if not isinstance(v, int) and not values["for_test"]: @@ -48,13 +63,6 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - path = config.datasets_folder - if self.id: - path = os.path.join(path, str(self.id)) - 
else: - path = os.path.join(path, self.generated_uid) - - self.path = path self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) @@ -86,48 +94,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - def todict(self): - return self.extended_dict() - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: - """Gets and creates instances of all the locally prepared datasets - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Dataset]: a list of Dataset instances. - """ - logging.info("Retrieving all datasets") - dsets = [] - if not local_only: - dsets = cls.__remote_all(filters=filters) - - remote_uids = set([dset.id for dset in dsets]) - - local_dsets = cls.__local_all() - - dsets += [dset for dset in local_dsets if dset.id not in remote_uids] - - return dsets - - @classmethod - def __remote_all(cls, filters: dict) -> List["Dataset"]: - dsets = [] - try: - comms_fn = cls.__remote_prefilter(filters) - dsets_meta = comms_fn() - dsets = [cls(**meta) for meta in dsets_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all datasets from the server" - logging.warning(msg) - - return dsets - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -149,111 +117,6 @@ def func(): return comms_fn - @classmethod - def __local_all(cls) -> List["Dataset"]: - dsets = [] - datasets_folder = config.datasets_folder - try: - uids = next(os.walk(datasets_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the 
dataset directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - dset = cls(**local_meta) - dsets.append(dset) - - return dsets - - @classmethod - def get(cls, dset_uid: Union[str, int], local_only: bool = False) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - if not str(dset_uid).isdigit() or local_only: - return cls.__local_get(dset_uid) - - try: - return cls.__remote_get(dset_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Dataset {dset_uid} from comms failed") - logging.info(f"Looking for dataset {dset_uid} locally") - return cls.__local_get(dset_uid) - - @classmethod - def __remote_get(cls, dset_uid: int) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} remotely") - meta = config.comms.get_dataset(dset_uid) - dataset = cls(**meta) - dataset.write() - return dataset - - @classmethod - def __local_get(cls, dset_uid: Union[str, int]) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} locally") - local_meta = cls.__get_local_dict(dset_uid) - dataset = cls(**local_meta) - return dataset - - def write(self): - logging.info(f"Updating registration information for dataset: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test datasets.") - dataset_dict = self.todict() - updated_dataset_dict = config.comms.upload_dataset(dataset_dict) - return updated_dataset_dict - - @classmethod - def __get_local_dict(cls, data_uid): - dataset_path = os.path.join(config.datasets_folder, str(data_uid)) - regfile = os.path.join(dataset_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested dataset information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..7a5f0b5ef 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,77 +1,215 @@ from typing import List, Dict, Union -from abc import ABC, abstractmethod +from abc import ABC +import logging +import os +import yaml +from medperf.exceptions import MedperfException, InvalidArgumentError +from medperf.entities.schemas import MedperfBaseSchema -class Entity(ABC): - @abstractmethod - def all( - cls, local_only: bool = False, comms_func: callable = None - ) -> 
List["Entity"]: +class Entity(MedperfBaseSchema, ABC): + @staticmethod + def get_type(): + raise NotImplementedError() + + @staticmethod + def get_storage_path(): + raise NotImplementedError() + + @staticmethod + def get_comms_retriever(): + raise NotImplementedError() + + @staticmethod + def get_metadata_filename(): + raise NotImplementedError() + + @staticmethod + def get_comms_uploader(): + raise NotImplementedError() + + @property + def identifier(self): + return self.id or self.generated_uid + + @property + def is_registered(self): + return self.id is not None + + @property + def path(self): + storage_path = self.get_storage_path() + return os.path.join(storage_path, str(self.identifier)) + + @classmethod + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + unregistered (bool, optional): Wether to retrieve only unregistered local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + Returns: List[Entity]: a list of entities. 
""" + logging.info(f"Retrieving all {cls.get_type()} entities") + if unregistered: + if filters: + raise InvalidArgumentError( + "Filtering is not supported for unregistered entities" + ) + return cls.__unregistered_all() + return cls.__remote_all(filters=filters) + + @classmethod + def __remote_all(cls, filters: dict) -> List["Entity"]: + comms_fn = cls.__remote_prefilter(filters) + entity_meta = comms_fn() + entities = [cls(**meta) for meta in entity_meta] + return entities + + @classmethod + def __unregistered_all(cls) -> List["Entity"]: + entities = [] + storage_path = cls.get_storage_path() + try: + uids = next(os.walk(storage_path))[1] + except StopIteration: + msg = f"Couldn't iterate over the {cls.get_type()} storage" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + if uid.isdigit(): + continue + meta = cls.__get_local_dict(uid) + entity = cls(**meta) + entities.append(entity) + + return entities + + @classmethod + def __remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + raise NotImplementedError - @abstractmethod - def get(cls, uid: Union[str, int]) -> "Entity": + @classmethod + def get(cls, uid: Union[str, int], local_only: bool = False) -> "Entity": """Gets an instance of the respective entity. Wether this requires only local read or remote calls depends on the implementation. 
Args: uid (str): Unique Identifier to retrieve the entity + local_only (bool): If True, the entity will be retrieved locally Returns: Entity: Entity Instance associated to the UID """ - @abstractmethod - def todict(self) -> Dict: - """Dictionary representation of the entity + if not str(uid).isdigit() or local_only: + return cls.__local_get(uid) + return cls.__remote_get(uid) + + @classmethod + def __remote_get(cls, uid: int) -> "Entity": + """Retrieves and creates an entity instance from the comms instance. + + Args: + uid (int): server UID of the entity Returns: - Dict: Dictionary containing information about the entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} remotely") + comms_func = cls.get_comms_retriever() + entity_dict = comms_func(uid) + entity = cls(**entity_dict) + entity.write() + return entity - @abstractmethod - def write(self) -> str: - """Writes the entity to the local storage + @classmethod + def __local_get(cls, uid: Union[str, int]) -> "Entity": + """Retrieves and creates an entity instance from the local storage. 
+ + Args: + uid (str|int): UID of the entity Returns: - str: Path to the stored entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} locally") + entity_dict = cls.__get_local_dict(uid) + entity = cls(**entity_dict) + return entity - @abstractmethod - def display_dict(self) -> dict: - """Returns a dictionary of entity properties that can be displayed - to a user interface using a verbose name of the property rather than - the internal names + @classmethod + def __get_local_dict(cls, uid: Union[str, int]) -> dict: + """Retrieves a local entity information + + Args: + uid (str): uid of the local entity Returns: - dict: the display dictionary + dict: information of the entity """ + logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") + storage_path = cls.get_storage_path() + metadata_filename = cls.get_metadata_filename() + bmk_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(bmk_file): + raise InvalidArgumentError( + f"No {cls.get_type()} with the given uid could be found" + ) + with open(bmk_file, "r") as f: + data = yaml.safe_load(f) + + return data + + def write(self) -> str: + """Writes the entity to the local storage + Returns: + str: Path to the stored entity + """ + data = self.todict() + metadata_filename = self.get_metadata_filename() + entity_file = os.path.join(self.path, metadata_filename) + os.makedirs(self.path, exist_ok=True) + with open(entity_file, "w") as f: + yaml.dump(data, f) + return entity_file -class Uploadable: - @abstractmethod def upload(self) -> Dict: """Upload the entity-related information to the communication's interface Returns: Dict: Dictionary with the updated entity information """ + if self.for_test: + raise InvalidArgumentError( + f"This test {self.get_type()} cannot be uploaded." 
+ ) + body = self.todict() + comms_func = self.get_comms_uploader() + updated_body = comms_func(body) + return updated_body - @property - def identifier(self): - return self.id or self.generated_uid + def display_dict(self) -> dict: + """Returns a dictionary of entity properties that can be displayed + to a user interface using a verbose name of the property rather than + the internal names - @property - def is_registered(self): - return self.id is not None + Returns: + dict: the display dictionary + """ + raise NotImplementedError diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index c76f09894..65147e558 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,16 +1,11 @@ import hashlib -import os -import yaml -import logging from typing import List, Union, Optional -from medperf.entities.schemas import MedperfBaseSchema import medperf.config as config -from medperf.exceptions import InvalidArgumentError from medperf.entities.interface import Entity -class TestReport(Entity, MedperfBaseSchema): +class TestReport(Entity): """ Class representing a compatibility test report entry @@ -35,11 +30,23 @@ class TestReport(Entity, MedperfBaseSchema): data_evaluator_mlcube: Union[int, str] results: Optional[dict] + @staticmethod + def get_type(): + return "report" + + @staticmethod + def get_storage_path(): + return config.tests_folder + + @staticmethod + def get_metadata_filename(): + return config.test_report_file + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.id = None + self.for_test = True self.generated_uid = self.__generate_uid() - path = config.tests_folder - self.path = os.path.join(path, self.generated_uid) def __generate_uid(self): """A helper that generates a unique hash for a test report.""" @@ -52,71 +59,14 @@ def set_results(self, results): self.results = results @classmethod - def all( - cls, local_only: bool = False, mine_only: bool = False - ) -> List["TestReport"]: 
- """Gets and creates instances of test reports. - Arguments are only specified for compatibility with - `Entity.List` and `Entity.View`, but they don't contribute to - the logic. - - Returns: - List[TestReport]: List containing all test reports - """ - logging.info("Retrieving all reports") - reports = [] - tests_folder = config.tests_folder - try: - uids = next(os.walk(tests_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the tests directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - report = cls(**local_meta) - reports.append(report) - - return reports - - @classmethod - def get(cls, report_uid: str) -> "TestReport": - """Retrieves and creates a TestReport instance obtained the user's machine - - Args: - report_uid (str): UID of the TestReport instance - - Returns: - TestReport: Specified TestReport instance - """ - logging.debug(f"Retrieving report {report_uid}") - report_dict = cls.__get_local_dict(report_uid) - report = cls(**report_dict) - report.write() - return report - - def todict(self): - return self.extended_dict() - - def write(self): - report_file = os.path.join(self.path, config.test_report_file) - os.makedirs(self.path, exist_ok=True) - with open(report_file, "w") as f: - yaml.dump(self.todict(), f) - return report_file + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["Entity"]: + assert unregistered, "Reports are only unregistered" + assert filters == {}, "Reports cannot be filtered" + return super().all(unregistered=True, filters={}) @classmethod - def __get_local_dict(cls, local_uid): - report_path = os.path.join(config.tests_folder, str(local_uid)) - report_file = os.path.join(report_path, config.test_report_file) - if not os.path.exists(report_file): - raise InvalidArgumentError( - f"The requested report {local_uid} could not be retrieved" - ) - with open(report_file, "r") as f: - report_info = yaml.safe_load(f) - return report_info 
+ def get(cls, report_uid: str, local_only: bool = False) -> "TestReport": + return super().get(report_uid, local_only=True) def display_dict(self): if self.data_path: diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index c82add87b..af4098521 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,16 +1,10 @@ -import os -import yaml -import logging -from typing import List, Union - -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema, ApprovableSchema import medperf.config as config -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.account_management import get_medperf_user_data -class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): +class Result(Entity, MedperfSchema, ApprovableSchema): """ Class representing a Result entry @@ -28,59 +22,34 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): metadata: dict = {} user_metadata: dict = {} - def __init__(self, *args, **kwargs): - """Creates a new result instance""" - super().__init__(*args, **kwargs) - - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" - path = config.results_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: - """Gets and creates instances of all the user's results - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
- - Returns: - List[Result]: List containing all results - """ - logging.info("Retrieving all results") - results = [] - if not local_only: - results = cls.__remote_all(filters=filters) - - remote_uids = set([result.id for result in results]) + @staticmethod + def get_type(): + return "result" - local_results = cls.__local_all() + @staticmethod + def get_storage_path(): + return config.results_folder - results += [res for res in local_results if res.id not in remote_uids] + @staticmethod + def get_comms_retriever(): + return config.comms.get_result - return results + @staticmethod + def get_metadata_filename(): + return config.results_info_file - @classmethod - def __remote_all(cls, filters: dict) -> List["Result"]: - results = [] + @staticmethod + def get_comms_uploader(): + return config.comms.upload_result - try: - comms_fn = cls.__remote_prefilter(filters) - results_meta = comms_fn() - results = [cls(**meta) for meta in results_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all results from the server" - logging.warning(msg) + def __init__(self, *args, **kwargs): + """Creates a new result instance""" + super().__init__(*args, **kwargs) - return results + self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,113 +73,6 @@ def get_benchmark_results(): return comms_fn - @classmethod - def __local_all(cls) -> List["Result"]: - results = [] - results_folder = config.results_folder - try: - uids = next(os.walk(results_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - result = cls(**local_meta) - results.append(result) - - return results - - @classmethod 
- def get(cls, result_uid: Union[str, int], local_only: bool = False) -> "Result": - """Retrieves and creates a Result instance obtained from the platform. - If the result instance already exists in the user's machine, it loads - the local instance - - Args: - result_uid (str): UID of the Result instance - - Returns: - Result: Specified Result instance - """ - if not str(result_uid).isdigit() or local_only: - return cls.__local_get(result_uid) - - try: - return cls.__remote_get(result_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Result {result_uid} from comms failed") - logging.info(f"Looking for result {result_uid} locally") - return cls.__local_get(result_uid) - - @classmethod - def __remote_get(cls, result_uid: int) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} remotely") - meta = config.comms.get_result(result_uid) - result = cls(**meta) - result.write() - return result - - @classmethod - def __local_get(cls, result_uid: Union[str, int]) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} locally") - local_meta = cls.__get_local_dict(result_uid) - result = cls(**local_meta) - return result - - def todict(self): - return self.extended_dict() - - def upload(self): - """Uploads the results to the comms - - Args: - comms (Comms): Instance of the communications interface. 
- """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test results.") - results_info = self.todict() - updated_results_info = config.comms.upload_result(results_info) - return updated_results_info - - def write(self): - result_file = os.path.join(self.path, config.results_info_file) - os.makedirs(self.path, exist_ok=True) - with open(result_file, "w") as f: - yaml.dump(self.todict(), f) - return result_file - - @classmethod - def __get_local_dict(cls, local_uid): - result_path = os.path.join(config.results_folder, str(local_uid)) - result_file = os.path.join(result_path, config.results_info_file) - if not os.path.exists(result_file): - raise InvalidArgumentError( - f"The requested result {local_uid} could not be retrieved" - ) - with open(result_file, "r") as f: - results_info = yaml.safe_load(f) - return results_info - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 0e7a54291..cac3d3a01 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -46,7 +46,7 @@ def dict(self, *args, **kwargs) -> dict: out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict - def extended_dict(self) -> dict: + def todict(self) -> dict: """Dictionary containing both original and alias fields Returns: @@ -74,7 +74,7 @@ class Config: use_enum_values = True -class MedperfSchema(MedperfBaseSchema): +class MedperfSchema(BaseModel): for_test: bool = False id: Optional[int] name: str = Field(..., max_length=64) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 74299c77e..c69544781 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -57,6 +57,9 @@ def mock_result_all(mocker, state_variables): TestResult(benchmark=triplet[0], model=triplet[1], dataset=triplet[2]) for triplet in 
cached_results_triplets ] + mocker.patch( + PATCH_EXECUTION.format("get_medperf_user_data", return_value={"id": 1}) + ) mocker.patch(PATCH_EXECUTION.format("Result.all"), return_value=results) diff --git a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..ce7035960 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -47,18 +47,18 @@ def set_common_attributes(self, setup): self.state_variables = state_variables self.spies = spies - @pytest.mark.parametrize("local_only", [False, True]) + @pytest.mark.parametrize("unregistered", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + def test_entity_all_is_called_properly(self, mocker, unregistered, mine_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(Entity, [], unregistered, mine_only) # Assert self.spies["all"].assert_called_once_with( - local_only=local_only, filters=filters + unregistered=unregistered, filters=filters ) @pytest.mark.parametrize("fields", [["UID", "MLCube"]]) diff --git a/cli/medperf/tests/commands/test_view.py b/cli/medperf/tests/commands/test_view.py index a2dddfeda..0ffe0fb13 100644 --- a/cli/medperf/tests/commands/test_view.py +++ b/cli/medperf/tests/commands/test_view.py @@ -1,143 +1,86 @@ import pytest -import yaml -import json from medperf.entities.interface import Entity -from medperf.exceptions import InvalidArgumentError from medperf.commands.view import EntityView - -def expected_output(entities, format): - if isinstance(entities, list): - data = [entity.todict() for entity in entities] - else: - data = entities.todict() - - if format == "yaml": - return yaml.dump(data) - if format == "json": - return json.dumps(data) - - -def generate_entity(id, mocker): - entity = mocker.create_autospec(spec=Entity) - 
mocker.patch.object(entity, "todict", return_value={"id": id}) - return entity +PATCH_VIEW = "medperf.commands.view.{}" @pytest.fixture -def ui_spy(mocker, ui): - return mocker.patch.object(ui, "print") +def entity(mocker): + return mocker.create_autospec(Entity) -@pytest.fixture( - params=[{"local": ["1", "2", "3"], "remote": ["4", "5", "6"], "user": ["4"]}] -) -def setup(request, mocker): - local_ids = request.param.get("local", []) - remote_ids = request.param.get("remote", []) - user_ids = request.param.get("user", []) - all_ids = list(set(local_ids + remote_ids + user_ids)) - - local_entities = [generate_entity(id, mocker) for id in local_ids] - remote_entities = [generate_entity(id, mocker) for id in remote_ids] - user_entities = [generate_entity(id, mocker) for id in user_ids] - all_entities = list(set(local_entities + remote_entities + user_entities)) - - def mock_all(filters={}, local_only=False): - if "owner" in filters: - return user_entities - if local_only: - return local_entities - return all_entities - - def mock_get(entity_id): - if entity_id in all_ids: - return generate_entity(entity_id, mocker) - else: - raise InvalidArgumentError - - mocker.patch("medperf.commands.view.get_medperf_user_data", return_value={"id": 1}) - mocker.patch.object(Entity, "all", side_effect=mock_all) - mocker.patch.object(Entity, "get", side_effect=mock_get) - - return local_entities, remote_entities, user_entities, all_entities - - -class TestViewEntityID: - def test_view_displays_entity_if_given(self, mocker, setup, ui_spy): - # Arrange - entity_id = "1" - entity = generate_entity(entity_id, mocker) - output = expected_output(entity, "yaml") - - # Act - EntityView.run(entity_id, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_all_if_no_id(self, setup, ui_spy): - # Arrange - *_, entities = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity) - - # Assert - 
ui_spy.assert_called_once_with(output) - - -class TestViewFilteredEntities: - def test_view_displays_local_entities(self, setup, ui_spy): - # Arrange - entities, *_ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, local_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_user_entities(self, setup, ui_spy): - # Arrange - *_, entities, _ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, mine_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - -@pytest.mark.parametrize("entity_id", ["4", None]) -@pytest.mark.parametrize("format", ["yaml", "json"]) -class TestViewOutput: - @pytest.fixture - def output(self, setup, mocker, entity_id, format): - if entity_id is None: - *_, entities = setup - return expected_output(entities, format) - else: - entity = generate_entity(entity_id, mocker) - return expected_output(entity, format) - - def test_view_displays_specified_format(self, entity_id, output, ui_spy, format): - # Act - EntityView.run(entity_id, Entity, format=format) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_stores_specified_format(self, entity_id, output, format, fs): - # Arrange - filename = "file.txt" - - # Act - EntityView.run(entity_id, Entity, format=format, output=filename) - - # Assert - contents = open(filename, "r").read() - assert contents == output +@pytest.fixture +def entity_view(mocker): + view_class = EntityView(None, Entity, "", "", "", "") + return view_class + + +def test_prepare_with_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = 1 + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + get_spy.assert_called_once_with(1) + all_spy.assert_not_called() + assert not isinstance(entity_view.data, list) + 
+ +def test_prepare_with_no_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once() + get_spy.assert_not_called() + assert isinstance(entity_view.data, list) + + +@pytest.mark.parametrize("unregistered", [False, True]) +def test_prepare_with_no_id_calls_all_with_unregistered_properly( + mocker, entity_view, entity, unregistered +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = unregistered + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=unregistered, filters={}) + + +@pytest.mark.parametrize("filters", [{}, {"f1": "v1"}]) +@pytest.mark.parametrize("mine_only", [False, True]) +def test_prepare_with_no_id_calls_all_with_proper_filters( + mocker, entity_view, entity, filters, mine_only +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = False + entity_view.filters = filters + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + mocker.patch(PATCH_VIEW.format("get_medperf_user_data"), return_value={"id": 1}) + if mine_only: + filters["owner"] = 1 + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=False, filters=filters) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 3f1fde2e2..c36771e12 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -9,8 +9,9 @@ @pytest.fixture( params={ - "local": [1, 2, 3], - "remote": [4, 5, 6], + "unregistered": ["b1", "b2"], 
+ "local": ["b1", "b2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], "user": [4], "models": [10, 11], } diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index b82b9a0e8..51234f6e3 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,7 +24,14 @@ } -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["c1", "c2"], + "local": ["c1", "c2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) @@ -282,7 +289,9 @@ def test_run_stops_execution_if_child_fails(self, mocker, setup, task): cube.run(task) -@pytest.mark.parametrize("setup", [{"local": [DEFAULT_CUBE]}], indirect=True) +@pytest.mark.parametrize( + "setup", [{"local": [DEFAULT_CUBE], "remote": [DEFAULT_CUBE]}], indirect=True +) @pytest.mark.parametrize("task", ["task"]) @pytest.mark.parametrize( "out_key,out_value", diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index c636b2c26..b9d309f39 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -15,7 +15,7 @@ setup_result_fs, setup_result_comms, ) -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -23,7 +23,14 @@ def Implementation(request): return request.param -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, Implementation, fs): local_ids = request.param.get("local", []) remote_ids = 
request.param.get("remote", []) @@ -54,39 +61,52 @@ def setup(request, mocker, comms, Implementation, fs): @pytest.mark.parametrize( "setup", - [{"local": [283, 17, 493], "remote": [283, 1, 2], "user": [2]}], + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 283], + "remote": [283, 1, 2], + "user": [2], + } + ], indirect=True, ) class TestAll: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.ids = setup + self.unregistered_ids = set(self.ids["unregistered"]) self.local_ids = set(self.ids["local"]) self.remote_ids = set(self.ids["remote"]) self.user_ids = set(self.ids["user"]) - def test_all_returns_all_remote_and_local(self, Implementation): - # Arrange - all_ids = self.local_ids.union(self.remote_ids) - + def test_all_returns_all_remote_by_default(self, Implementation): # Act entities = Implementation.all() # Assert retrieved_ids = set([e.todict()["id"] for e in entities]) - assert all_ids == retrieved_ids + assert self.remote_ids == retrieved_ids - def test_all_local_only_returns_all_local(self, Implementation): + def test_all_unregistered_returns_all_unregistered(self, Implementation): # Act - entities = Implementation.all(local_only=True) + entities = Implementation.all(unregistered=True) # Assert - retrieved_ids = set([e.todict()["id"] for e in entities]) - assert self.local_ids == retrieved_ids + retrieved_names = set([e.name for e in entities]) + assert self.unregistered_ids == retrieved_names @pytest.mark.parametrize( - "setup", [{"local": [78], "remote": [479, 42, 7, 1]}], indirect=True + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 479], + "remote": [479, 42, 7, 1], + } + ], + indirect=True, ) class TestGet: def test_get_retrieves_entity_from_server(self, Implementation, setup): @@ -99,30 +119,20 @@ def test_get_retrieves_entity_from_server(self, Implementation, setup): # Assert assert entity.todict()["id"] == id - def test_get_retrieves_entity_local_if_not_on_server(self, Implementation, 
setup): - # Arrange - id = setup["local"][0] - - # Act - entity = Implementation.get(id) - - # Assert - assert entity.todict()["id"] == id - def test_get_raises_error_if_nonexistent(self, Implementation, setup): # Arrange id = str(19283) # Act & Assert - with pytest.raises(InvalidArgumentError): + with pytest.raises(CommunicationRetrievalError): Implementation.get(id) -@pytest.mark.parametrize("setup", [{"local": [742]}], indirect=True) +@pytest.mark.parametrize("setup", [{"remote": [742]}], indirect=True) class TestToDict: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): - self.id = setup["local"][0] + self.id = setup["remote"][0] def test_todict_returns_dict_representation(self, Implementation): # Arrange @@ -147,7 +157,16 @@ def test_todict_can_recreate_object(self, Implementation): assert ent_dict == ent_copy_dict -@pytest.mark.parametrize("setup", [{"local": [36]}], indirect=True) +@pytest.mark.parametrize( + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2"], + } + ], + indirect=True, +) class TestUpload: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 522251ca7..19c3178e3 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -15,14 +15,17 @@ # Setup Benchmark def setup_benchmark_fs(ents, fs): - bmks_path = config.benchmarks_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename) - bmk_contents = TestBenchmark(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + bmk_contents = TestBenchmark(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + bmk_contents = TestBenchmark(id=str(ent)) + else: + bmk_contents = TestBenchmark(id=None, name=ent) + 
bmk_contents.generated_uid = ent + + bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -30,7 +33,7 @@ def setup_benchmark_fs(ents, fs): cubes_ids = list(set(cubes_ids)) setup_cube_fs(cubes_ids, fs) try: - fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.dict())) + fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.todict())) except FileExistsError: pass @@ -51,17 +54,18 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Cube def setup_cube_fs(ents, fs): - cubes_path = config.cubes_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - meta_cube_file = os.path.join( - cubes_path, str(id), config.cube_metadata_filename - ) - cube = TestCube(**ent) - meta = cube.dict() + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + cube = TestCube(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + cube = TestCube(id=str(ent)) + else: + cube = TestCube(id=None, name=ent) + cube.generated_uid = ent + + meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) except FileExistsError: @@ -124,18 +128,21 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): - dsets_path = config.datasets_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - reg_dset_file = os.path.join(dsets_path, str(id), config.reg_file) - dset_contents = TestDataset(**ent) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + dset_contents = TestDataset(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + dset_contents = 
TestDataset(id=str(ent)) + else: + dset_contents = TestDataset(id=None, name=ent) + dset_contents.generated_uid = ent + + reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: - fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.dict())) + fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.todict())) except FileExistsError: pass @@ -155,22 +162,26 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Result def setup_result_fs(ents, fs): - results_path = config.results_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - result_file = os.path.join(results_path, str(id), config.results_info_file) - bmk_id = ent.get("benchmark", 1) - cube_id = ent.get("model", 1) - dataset_id = ent.get("dataset", 1) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + result_contents = TestResult(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + result_contents = TestResult(id=str(ent)) + else: + result_contents = TestResult(id=None, name=ent) + result_contents.generated_uid = ent + + result_file = os.path.join(result_contents.path, config.results_info_file) + bmk_id = result_contents.benchmark + cube_id = result_contents.model + dataset_id = result_contents.dataset setup_benchmark_fs([bmk_id], fs) setup_cube_fs([cube_id], fs) setup_dset_fs([dataset_id], fs) - result_contents = TestResult(**ent) + try: - fs.create_file(result_file, contents=yaml.dump(result_contents.dict())) + fs.create_file(result_file, contents=yaml.dump(result_contents.todict())) except FileExistsError: pass From 89f256bc68c0cc17998bc07ff0629096d7f021a1 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 06:27:26 +0200 Subject: [PATCH 044/242] client updates --- cli/medperf/certificates.py | 47 + cli/medperf/cli.py | 8 +- 
cli/medperf/commands/aggregator/aggregator.py | 41 +- cli/medperf/commands/aggregator/associate.py | 10 +- cli/medperf/commands/aggregator/run.py | 103 +- cli/medperf/commands/aggregator/submit.py | 39 +- cli/medperf/commands/association/approval.py | 50 +- .../commands/association/association.py | 57 +- cli/medperf/commands/association/list.py | 58 +- cli/medperf/commands/association/priority.py | 4 +- cli/medperf/commands/association/utils.py | 119 ++ cli/medperf/commands/ca/associate.py | 34 + cli/medperf/commands/ca/ca.py | 102 ++ cli/medperf/commands/ca/submit.py | 65 ++ .../commands/certificate/certificate.py | 38 + .../certificate/client_certificate.py | 18 + .../certificate/server_certificate.py | 19 + cli/medperf/commands/dataset/associate.py | 69 +- .../commands/dataset/associate_benchmark.py | 52 + .../commands/dataset/associate_training.py | 39 + cli/medperf/commands/dataset/dataset.py | 15 +- cli/medperf/commands/mlcube/associate.py | 2 +- cli/medperf/commands/result/create.py | 2 + cli/medperf/commands/training/approve.py | 38 - cli/medperf/commands/training/associate.py | 45 - cli/medperf/commands/training/close_event.py | 59 + cli/medperf/commands/training/list_assocs.py | 41 - cli/medperf/commands/training/lock.py | 23 - cli/medperf/commands/training/run.py | 94 +- cli/medperf/commands/training/set_plan.py | 86 ++ cli/medperf/commands/training/start_event.py | 58 + cli/medperf/commands/training/submit.py | 6 +- cli/medperf/commands/training/training.py | 86 +- cli/medperf/comms/rest.py | 1000 +++++++++-------- cli/medperf/config.py | 13 +- cli/medperf/cryptography/__init__.py | 3 - cli/medperf/cryptography/ca.py | 150 --- cli/medperf/cryptography/io.py | 129 --- cli/medperf/cryptography/participant.py | 72 -- cli/medperf/cryptography/utils.py | 14 - cli/medperf/entities/aggregator.py | 227 +--- cli/medperf/entities/benchmark.py | 11 +- cli/medperf/entities/ca.py | 115 ++ cli/medperf/entities/event.py | 100 ++ cli/medperf/entities/training_exp.py | 
282 +---- .../commands/association/test_approve.py | 4 +- .../commands/association/test_priority.py | 2 +- .../commands/benchmark/test_associate.py | 4 +- .../tests/commands/dataset/test_associate.py | 6 +- .../tests/commands/mlcube/test_associate.py | 4 +- cli/medperf/tests/comms/test_rest.py | 22 +- cli/medperf/utils.py | 77 +- examples/fl/cert/mlcube/mlcube.yaml | 38 + .../fl/cert/mlcube/workspace/ca_config.json | 7 + examples/fl/cert/project/Dockerfile | 7 + examples/fl/cert/project/get_cert.sh | 62 + examples/fl/cert/project/trust.sh | 42 + server/aggregator/models.py | 6 +- server/ca/models.py | 10 +- server/dataset/serializers.py | 12 + server/dataset/urls.py | 6 +- server/medperf/urls.py | 1 + server/training/urls.py | 8 +- server/training/views.py | 66 ++ server/trainingevent/models.py | 7 + server/trainingevent/serializers.py | 10 +- server/trainingevent/urls.py | 9 + server/trainingevent/views.py | 36 +- 68 files changed, 2222 insertions(+), 1867 deletions(-) create mode 100644 cli/medperf/certificates.py create mode 100644 cli/medperf/commands/association/utils.py create mode 100644 cli/medperf/commands/ca/associate.py create mode 100644 cli/medperf/commands/ca/ca.py create mode 100644 cli/medperf/commands/ca/submit.py create mode 100644 cli/medperf/commands/certificate/certificate.py create mode 100644 cli/medperf/commands/certificate/client_certificate.py create mode 100644 cli/medperf/commands/certificate/server_certificate.py create mode 100644 cli/medperf/commands/dataset/associate_benchmark.py create mode 100644 cli/medperf/commands/dataset/associate_training.py delete mode 100644 cli/medperf/commands/training/approve.py delete mode 100644 cli/medperf/commands/training/associate.py create mode 100644 cli/medperf/commands/training/close_event.py delete mode 100644 cli/medperf/commands/training/list_assocs.py delete mode 100644 cli/medperf/commands/training/lock.py create mode 100644 cli/medperf/commands/training/set_plan.py create mode 100644 
cli/medperf/commands/training/start_event.py delete mode 100644 cli/medperf/cryptography/__init__.py delete mode 100644 cli/medperf/cryptography/ca.py delete mode 100644 cli/medperf/cryptography/io.py delete mode 100644 cli/medperf/cryptography/participant.py delete mode 100644 cli/medperf/cryptography/utils.py create mode 100644 cli/medperf/entities/ca.py create mode 100644 cli/medperf/entities/event.py create mode 100644 examples/fl/cert/mlcube/mlcube.yaml create mode 100644 examples/fl/cert/mlcube/workspace/ca_config.json create mode 100644 examples/fl/cert/project/Dockerfile create mode 100644 examples/fl/cert/project/get_cert.sh create mode 100644 examples/fl/cert/project/trust.sh create mode 100644 server/trainingevent/urls.py diff --git a/cli/medperf/certificates.py b/cli/medperf/certificates.py new file mode 100644 index 000000000..dfb9afc9c --- /dev/null +++ b/cli/medperf/certificates.py @@ -0,0 +1,47 @@ +from medperf.entities.ca import CA +from medperf.entities.cube import Cube + + +def get_client_cert(ca: CA, email: str, output_path: str): + """Responsible for getting a user cert""" + common_name = email + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": output_path, + } + env = {"MEDPERF_INPUT_CN": common_name} + + mlcube = Cube.get(ca.client_mlcube) + mlube_task = "get_client_cert" + mlcube.run(task=mlube_task, env_dict=env, **params) + + +def get_server_cert(ca: CA, address: str, output_path: str): + """Responsible for getting a server cert""" + common_name = address + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": output_path, + } + env = {"MEDPERF_INPUT_CN": common_name} + + mlcube = Cube.get(ca.server_mlcube) + mlube_task = "get_server_cert" + mlcube.run(task=mlube_task, env_dict=env, port=80, **params) + + +def trust(ca: CA): + """Verifies the CA cert fingerprint and writes it to the MedPerf storage. 
+ This is needed when running a workload, either by the users or + by the aggregator + """ + ca.prepare_config() + params = { + "ca_config": ca.config_path, + "pki_assets": ca.pki_assets, + } + mlcube = Cube.get(ca.ca_mlcube) + mlube_task = "trust" + mlcube.run(task=mlube_task, **params) diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index b58ff7cb4..242d9b734 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -18,6 +18,8 @@ import medperf.commands.compatibility_test.compatibility_test as compatibility_test import medperf.commands.training.training as training import medperf.commands.aggregator.aggregator as aggregator +import medperf.commands.ca.ca as ca +import medperf.commands.certificate.certificate as certificate import medperf.commands.storage as storage # from medperf.utils import check_for_updates @@ -33,8 +35,10 @@ app.add_typer(compatibility_test.app, name="test", help="Manage compatibility tests") app.add_typer(auth.app, name="auth", help="Authentication") app.add_typer(storage.app, name="storage", help="Storage management") -app.add_typer(training.app, name="training", help="Training") -app.add_typer(aggregator.app, name="aggregator", help="Aggregator") +app.add_typer(training.app, name="training", help="Manage training experiments") +app.add_typer(aggregator.app, name="aggregator", help="Manage aggregators") +app.add_typer(ca.app, name="ca", help="Manage CAs") +app.add_typer(certificate.app, name="certificate", help="Manage certificates") @app.command("run") diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py index 4bf5925f2..f640d8f38 100644 --- a/cli/medperf/commands/aggregator/aggregator.py +++ b/cli/medperf/commands/aggregator/aggregator.py @@ -17,16 +17,19 @@ @app.command("submit") @clean_except def submit( - name: str = typer.Option(..., "--name", "-n", help="Name of the agg"), + name: str = typer.Option(..., "--name", "-n", help="Name of the aggregator"), address: str = 
typer.Option( - ..., "--address", "-a", help="UID of benchmark to associate with" + ..., "--address", "-a", help="Address/domain of the aggregator" ), port: int = typer.Option( - ..., "--port", "-p", help="UID of benchmark to associate with" + ..., "--port", "-p", help="The port which the aggregator will use" + ), + aggregation_mlcube: int = typer.Option( + ..., "--aggregation-mlcube", "-m", help="Aggregation MLCube UID" ), ): - """Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" - SubmitAggregator.run(name, address, port) + """Submits an aggregator""" + SubmitAggregator.run(name, address, port, aggregation_mlcube) config.ui.print("✅ Done!") @@ -49,29 +52,31 @@ def associate( @app.command("start") @clean_except def run( - aggregator_id: int = typer.Option( - ..., "--aggregator_id", "-a", help="UID of benchmark to associate with" - ), training_exp_id: int = typer.Option( - ..., "--training_exp_id", "-t", help="UID of benchmark to associate with" + ..., + "--training_exp_id", + "-t", + help="UID of training experiment whose aggregator to be run", ), ): - """Associates a benchmark with a given mlcube or dataset. Only one option at a time.""" - StartAggregator.run(training_exp_id, aggregator_id) + """Starts the aggregation server of a training experiment""" + StartAggregator.run(training_exp_id) config.ui.print("✅ Done!") @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local aggregators"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered aggregators" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user aggregators"), ): - """List aggregators stored locally and remotely from the user""" + """List aggregators""" EntityList.run( Aggregator, fields=["UID", "Name", "Address", "Port"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -86,10 +91,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -104,4 +109,4 @@ def view( ), ): """Displays the information of one or more aggregators""" - EntityView.run(entity_id, Aggregator, format, local, mine, output) + EntityView.run(entity_id, Aggregator, format, unregistered, mine, output) diff --git a/cli/medperf/commands/aggregator/associate.py b/cli/medperf/commands/aggregator/associate.py index 0e21a6b66..0222c3aef 100644 --- a/cli/medperf/commands/aggregator/associate.py +++ b/cli/medperf/commands/aggregator/associate.py @@ -1,14 +1,14 @@ from medperf import config from medperf.entities.aggregator import Aggregator from medperf.entities.training_exp import TrainingExp -from medperf.utils import approval_prompt, generate_agg_csr +from medperf.utils import approval_prompt from medperf.exceptions import InvalidArgumentError class AssociateAggregator: @staticmethod def run(training_exp_id: int, agg_uid: int, approved=False): - """Associates a registered aggregator with a benchmark + """Associates an aggregator with a training experiment Args: agg_uid (int): UID of the registered aggregator to associate @@ -22,17 +22,13 @@ def run(training_exp_id: int, agg_uid: int, approved=False): raise InvalidArgumentError(msg) training_exp = TrainingExp.get(training_exp_id) - csr, csr_hash = generate_agg_csr(training_exp_id, agg.address, agg.id) msg = "Please confirm that you would like to associate" msg += f" the aggregator {agg.name} with the training exp {training_exp.name}." 
- msg += f" The certificate signing request hash is: {csr_hash}" msg += " [Y/n]" approved = approved or approval_prompt(msg) if approved: ui.print("Generating aggregator training association") - # TODO: delete keys if upload fails - # check if on failure, other (possible) request will overwrite key - comms.associate_aggregator(agg.id, training_exp_id, csr) + comms.associate_training_aggregator(agg.id, training_exp_id) else: ui.print("Aggregator association operation cancelled.") diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index a38e99155..edf129cda 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -1,79 +1,49 @@ -import os from medperf import config +from medperf.entities.ca import CA +from medperf.entities.event import TrainingEvent from medperf.exceptions import InvalidArgumentError from medperf.entities.training_exp import TrainingExp from medperf.entities.aggregator import Aggregator from medperf.entities.cube import Cube +from medperf.utils import get_pki_assets_path +from medperf.certificates import trust class StartAggregator: @classmethod - def run(cls, training_exp_id: int, agg_uid: int): - """Sets approval status for an association between a benchmark and a aggregator or mlcube + def run(cls, training_exp_id: int): + """Starts the aggregation server of a training experiment Args: - benchmark_uid (int): Benchmark UID. - approval_status (str): Desired approval status to set for the association. - comms (Comms): Instance of Comms interface. - ui (UI): Instance of UI interface. - aggregator_uid (int, optional): Aggregator UID. Defaults to None. - mlcube_uid (int, optional): MLCube UID. Defaults to None. + training_exp_id (int): Training experiment UID. 
""" - execution = cls(training_exp_id, agg_uid) + execution = cls(training_exp_id) execution.prepare() execution.validate() - execution.prepare_agg_cert() - execution.prepare_cube() + execution.prepare_aggregator() + execution.prepare_participants_list() + execution.prepare_plan() + execution.prepare_pki_assets() with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id, agg_uid) -> None: + def __init__(self, training_exp_id) -> None: self.training_exp_id = training_exp_id - self.agg_uid = agg_uid self.ui = config.ui def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Execution: {self.training_exp.name}") - self.aggregator = Aggregator.get(self.agg_uid) + self.event = TrainingEvent.from_experiment(self.training_exp_id) def validate(self): - if self.aggregator.id is None: - msg = "The provided aggregator is not registered." + if self.event.finished(): + msg = "The provided training experiment has to start a training event." raise InvalidArgumentError(msg) - training_exp_aggregator = config.comms.get_experiment_aggregator( - self.training_exp.id - ) - - if self.aggregator.id != training_exp_aggregator["id"]: - msg = "The provided aggregator is not associated." - raise InvalidArgumentError(msg) - - if self.training_exp.state != "OPERATION": - msg = "The provided training exp is not operational." 
- raise InvalidArgumentError(msg) - - def prepare_agg_cert(self): - association = config.comms.get_aggregator_association( - self.training_exp.id, self.aggregator.id - ) - cert = association["certificate"] - cert_folder = os.path.join( - config.training_folder, - str(self.training_exp.id), - config.agg_cert_folder, - str(self.aggregator.id), - ) - os.makedirs(cert_folder, exist_ok=True) - cert_file = os.path.join(cert_folder, "cert.crt") - with open(cert_file, "w") as f: - f.write(cert) - - self.agg_cert_path = cert_folder - - def prepare_cube(self): - self.cube = self.__get_cube(self.training_exp.fl_mlcube, "training") + def prepare_aggregator(self): + self.aggregator = Aggregator.from_experiment(self.training_exp_id) + self.cube = self.__get_cube(self.aggregator.aggregation_mlcube, "aggregation") def __get_cube(self, uid: int, name: str) -> Cube: self.ui.text = f"Retrieving {name} cube" @@ -82,22 +52,29 @@ def __get_cube(self, uid: int, name: str) -> Cube: self.ui.print(f"> {name} cube download complete") return cube - def run_experiment(self): - task = "start_aggregator" - # just for now create some output folders (TODO) - out_logs = os.path.join(self.training_exp.path, "logs") - out_weights = os.path.join(self.training_exp.path, "weights") - os.makedirs(out_logs, exist_ok=True) - os.makedirs(out_weights, exist_ok=True) + def prepare_participants_list(self): + self.event.prepare_participants_list() + + def prepare_plan(self): + self.training_exp.prepare_plan() + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + agg_address = self.aggregator.address + self.aggregator_pki_assets = get_pki_assets_path(agg_address, ca.name) + self.ca = ca + + def run_experiment(self): params = { - "node_cert_folder": self.agg_cert_path, - "ca_cert_folder": self.training_exp.cert_path, - "network_config": self.aggregator.network_config_path, - "collaborators": self.training_exp.cols_path, - "output_logs": out_logs, - "output_weights": 
out_weights, + "node_cert_folder": self.aggregator_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "collaborators": self.event.participants_list_path, + "output_logs": self.event.out_logs, + "output_weights": self.event.out_weights, + "report_path": self.event.report_path, } self.ui.text = "Running Aggregator" - self.cube.run(task=task, port=self.aggregator.port, **params) + self.cube.run(task="start_aggregator", port=self.aggregator.port, **params) diff --git a/cli/medperf/commands/aggregator/submit.py b/cli/medperf/commands/aggregator/submit.py index 53335cd63..3fd653fcb 100644 --- a/cli/medperf/commands/aggregator/submit.py +++ b/cli/medperf/commands/aggregator/submit.py @@ -1,45 +1,40 @@ import medperf.config as config from medperf.entities.aggregator import Aggregator from medperf.utils import remove_path +from medperf.entities.cube import Cube class SubmitAggregator: @classmethod - def run(cls, name, address, port): - """Submits a new cube to the medperf platform + def run(cls, name: str, address: str, port: int, aggregation_mlcube: int): + """Submits a new aggregator to the medperf platform Args: - benchmark_info (dict): benchmark information - expected keys: - name (str): benchmark name - description (str): benchmark description - docs_url (str): benchmark documentation url - demo_url (str): benchmark demo dataset url - demo_hash (str): benchmark demo dataset hash - data_preparation_mlcube (int): benchmark data preparation mlcube uid - reference_model_mlcube (int): benchmark reference model mlcube uid - evaluator_mlcube (int): benchmark data evaluator mlcube uid + name (str): aggregator name + address (str): aggregator address/domain + port (int): port which the aggregator will use + aggregation_mlcube (int): aggregation mlcube uid """ ui = config.ui - submission = cls(name, address, port) + submission = cls(name, address, port, aggregation_mlcube) with ui.interactive(): ui.text = "Submitting Aggregator to 
MedPerf" + submission.validate_agg_cube() updated_benchmark_body = submission.submit() ui.print("Uploaded") submission.write(updated_benchmark_body) - def __init__(self, name, address, port): + def __init__(self, name: str, address: str, port: int, aggregation_mlcube: int): self.ui = config.ui - # TODO: server config should be a URL... - server_config = { - "address": address, - "agg_addr": address, - "port": port, - "agg_port": port, - } - self.aggregator = Aggregator(name=name, server_config=server_config) + agg_config = {"address": address, "port": port} + self.aggregator = Aggregator( + name=name, config=agg_config, aggregation_mlcube=aggregation_mlcube + ) config.tmp_paths.append(self.aggregator.path) + def validate_agg_cube(self): + Cube.get(self.aggregator.aggregation_mlcube) + def submit(self): updated_body = self.aggregator.upload() return updated_body diff --git a/cli/medperf/commands/association/approval.py b/cli/medperf/commands/association/approval.py index 4ed343911..df486b420 100644 --- a/cli/medperf/commands/association/approval.py +++ b/cli/medperf/commands/association/approval.py @@ -1,14 +1,17 @@ from medperf import config -from medperf.exceptions import InvalidArgumentError +from medperf.commands.association.utils import validate_args class Approval: @staticmethod def run( - benchmark_uid: int, approval_status: str, + benchmark_uid: int = None, + training_exp_uid: int = None, dataset_uid: int = None, mlcube_uid: int = None, + aggregator_uid: int = None, + ca_uid: int = None, ): """Sets approval status for an association between a benchmark and a dataset or mlcube @@ -21,17 +24,34 @@ def run( mlcube_uid (int, optional): MLCube UID. Defaults to None. 
""" comms = config.comms - too_many_resources = dataset_uid and mlcube_uid - no_resource = dataset_uid is None and mlcube_uid is None - if no_resource or too_many_resources: - raise InvalidArgumentError("Must provide either a dataset or mlcube") + validate_args( + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + approval_status.value, + ) + update = {"approval_status": approval_status.value} + if benchmark_uid: + if dataset_uid: + comms.update_benchmark_dataset_association( + benchmark_uid, dataset_uid, update + ) - if dataset_uid: - comms.set_dataset_association_approval( - benchmark_uid, dataset_uid, approval_status.value - ) - - if mlcube_uid: - comms.set_mlcube_association_approval( - benchmark_uid, mlcube_uid, approval_status.value - ) + if mlcube_uid: + comms.update_benchmark_model_association( + benchmark_uid, mlcube_uid, update + ) + if training_exp_uid: + if dataset_uid: + comms.update_training_dataset_association( + benchmark_uid, dataset_uid, update + ) + if aggregator_uid: + comms.update_training_aggregator_association( + benchmark_uid, mlcube_uid, update + ) + if ca_uid: + comms.update_training_ca_association(benchmark_uid, mlcube_uid, update) diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py index fa69682ed..2e3bca2ac 100644 --- a/cli/medperf/commands/association/association.py +++ b/cli/medperf/commands/association/association.py @@ -1,5 +1,4 @@ import typer -from typing import Optional import medperf.config as config from medperf.decorators import clean_except @@ -13,22 +12,47 @@ @app.command("ls") @clean_except -def list(filter: Optional[str] = typer.Argument(None)): +def list( + benchmark: bool = typer.Option(False, "-b", help="list benchmark associations"), + training_exp: bool = typer.Option(False, "-t", help="list training associations"), + dataset: bool = typer.Option(False, "-d", help="list dataset associations"), + mlcube: bool = 
typer.Option(False, "-m", help="list mlcube associations"), + aggregator: bool = typer.Option(False, "-a", help="list aggregator associations"), + ca: bool = typer.Option(False, "-c", help="list ca associations"), + approval_status: str = typer.Option( + None, "--approval-status", help="Approval status" + ), +): """Display all associations related to the current user. Args: filter (str, optional): Filter associations by approval status. Defaults to displaying all user associations. """ - ListAssociations.run(filter) + ListAssociations.run( + benchmark, + training_exp, + dataset, + mlcube, + aggregator, + ca, + approval_status, + ) @app.command("approve") @clean_except def approve( benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + training_exp_uid: int = typer.Option( + ..., "--training_exp", "-t", help="Training exp UID" + ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), + aggregator_uid: int = typer.Option( + None, "--aggregator", "-a", help="Aggregator UID" + ), + ca_uid: int = typer.Option(None, "--ca", "-c", help="CA UID"), ): """Approves an association between a benchmark and a dataset or model mlcube @@ -37,7 +61,15 @@ def approve( dataset_uid (int, optional): Dataset UID. mlcube_uid (int, optional): Model MLCube UID. 
""" - Approval.run(benchmark_uid, Status.APPROVED, dataset_uid, mlcube_uid) + Approval.run( + Status.APPROVED, + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + ) config.ui.print("✅ Done!") @@ -45,8 +77,15 @@ def approve( @clean_except def reject( benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + training_exp_uid: int = typer.Option( + ..., "--training_exp", "-t", help="Training exp UID" + ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), + aggregator_uid: int = typer.Option( + None, "--aggregator", "-a", help="Aggregator UID" + ), + ca_uid: int = typer.Option(None, "--ca", "-c", help="CA UID"), ): """Rejects an association between a benchmark and a dataset or model mlcube @@ -55,7 +94,15 @@ def reject( dataset_uid (int, optional): Dataset UID. mlcube_uid (int, optional): Model MLCube UID. """ - Approval.run(benchmark_uid, Status.REJECTED, dataset_uid, mlcube_uid) + Approval.run( + Status.REJECTED, + benchmark_uid, + training_exp_uid, + dataset_uid, + mlcube_uid, + aggregator_uid, + ca_uid, + ) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/association/list.py b/cli/medperf/commands/association/list.py index e210fbc26..fbe77539a 100644 --- a/cli/medperf/commands/association/list.py +++ b/cli/medperf/commands/association/list.py @@ -1,47 +1,55 @@ from tabulate import tabulate from medperf import config +from medperf.commands.association.utils import validate_args, get_associations_list class ListAssociations: @staticmethod - def run(filter: str = None): - """Get Pending association requests""" - comms = config.comms - ui = config.ui - dset_assocs = comms.get_datasets_associations() - cube_assocs = comms.get_cubes_associations() + def run( + benchmark, + training_exp, + dataset, + mlcube, + aggregator, + ca, + approval_status, + ): + """Get user association requests""" 
+ validate_args( + benchmark, training_exp, dataset, mlcube, aggregator, ca, approval_status + ) + if training_exp: + experiment_key = "training_exp" + elif benchmark: + experiment_key = "benchmark" - # Might be worth seeing if creating an association class that encapsulates - # most of the logic here is useful - assocs = dset_assocs + cube_assocs - if filter: - filter = filter.upper() - assocs = [assoc for assoc in assocs if assoc["approval_status"] == filter] + if mlcube: + component_key = "model_mlcube" + elif dataset: + component_key = "dataset" + elif aggregator: + component_key = "aggregator" + elif ca: + component_key = "ca" + + assocs = get_associations_list(experiment_key, component_key, approval_status) assocs_info = [] for assoc in assocs: assoc_info = ( - assoc.get("dataset", None), - assoc.get("model_mlcube", None), - assoc["benchmark"], + assoc[component_key], + assoc[experiment_key], assoc["initiated_by"], assoc["approval_status"], - assoc.get("priority", None), - # NOTE: We should find a better way to show priorities, since a priority - # is better shown when listing cube associations only, of a specific - # benchmark. Maybe this is resolved after we add a general filtering - # feature to list commands. 
) assocs_info.append(assoc_info) headers = [ - "Dataset UID", - "MLCube UID", - "Benchmark UID", + f"{component_key.replace('_', ' ').title()} UID", + f"{experiment_key.replace('_', ' ').title()} UID", "Initiated by", "Status", - "Priority", ] tab = tabulate(assocs_info, headers=headers) - ui.print(tab) + config.ui.print(tab) diff --git a/cli/medperf/commands/association/priority.py b/cli/medperf/commands/association/priority.py index c58db2450..760b0f4c2 100644 --- a/cli/medperf/commands/association/priority.py +++ b/cli/medperf/commands/association/priority.py @@ -19,6 +19,6 @@ def run(benchmark_uid: int, mlcube_uid: int, priority: int): raise InvalidArgumentError( "The given mlcube doesn't exist or is not associated with the benchmark" ) - config.comms.set_mlcube_association_priority( - benchmark_uid, mlcube_uid, priority + config.comms.update_benchmark_model_association( + benchmark_uid, mlcube_uid, {"priority": priority} ) diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py new file mode 100644 index 000000000..83b785c5e --- /dev/null +++ b/cli/medperf/commands/association/utils.py @@ -0,0 +1,119 @@ +from medperf.exceptions import InvalidArgumentError +from medperf import config +from pydantic.datetime_parse import parse_datetime + + +def validate_args( + benchmark, training_exp, dataset, model_mlcube, aggregator, ca, approval_status +): + training_exp = bool(training_exp) + benchmark = bool(benchmark) + dataset = bool(dataset) + model_mlcube = bool(model_mlcube) + aggregator = bool(aggregator) + ca = bool(ca) + + if approval_status is not None: + if approval_status.lower() not in ["pending", "approved", "rejected"]: + raise InvalidArgumentError( + "If provided, approval status must be one of pending, approved, or rejected" + ) + if sum([benchmark, training_exp]) != 1: + raise InvalidArgumentError( + "One training experiment or a benchmark flag must be provided" + ) + if sum([dataset, model_mlcube, aggregator, 
ca]) != 1: + raise InvalidArgumentError( + "One dataset, mlcube, aggregator, or ca flag must be provided" + ) + if training_exp and model_mlcube: + raise InvalidArgumentError( + "Invalid combination of arguments. There are no associations" + " between training experiments and models mlcubes" + ) + if benchmark and (ca or aggregator): + raise InvalidArgumentError( + "Invalid combination of arguments. There are no associations" + " between benchmarks and CAs or aggregators" + ) + + +def filter_latest_associations(associations, experiment_key, component_key): + """Given a list of component-experiment associations, this function + retrieves a list containing the latest association of each + experiment-component instance. + + Args: + associations (list[dict]): the list of associations + experiment_key (str): experiment identifier field in the association + component_key (str): component identifier field in the association + + Returns: + list[dict]: the list containing the latest association of each + entity instance. 
+ """ + + associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) + latest_associations = {} + for assoc in associations: + component_id = assoc[component_key] + experiment_id = assoc[experiment_key] + latest_associations[(component_id, experiment_id)] = assoc + + latest_associations = list(latest_associations.values()) + return latest_associations + + +def get_last_component(associations, experiment_key): + associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) + experiments_component = {} + for assoc in associations: + experiment_id = assoc[experiment_key] + experiments_component[experiment_id] = assoc + + experiments_component = list(experiments_component.values()) + return experiments_component + + +def get_associations_list( + experiment_key: str, + component_key: str, + approval_status: str = None, + experiment_id: int = None, +): + comms_functions = { + "training_exp": { + "dataset": { + "user": config.comms.get_user_training_datasets_associations, + "experiment": config.comms.get_training_datasets_associations, + }, + "aggregator": config.comms.get_user_training_aggregators_associations, + "ca": config.comms.get_user_training_cas_associations, + }, + "benchmark": { + "dataset": config.comms.get_user_benchmarks_datasets_associations, + "mode_mlcube": { + "user": config.comms.get_user_benchmarks_models_associations, + "experiment": config.comms.get_benchmark_models_associations, + }, + }, + } + if experiment_id: + comms_func = comms_functions[experiment_key][component_key]["experiment"] + assocs = comms_func(experiment_id) + else: + comms_func = comms_functions[experiment_key][component_key]["user"] + assocs = comms_func() + + assocs = filter_latest_associations(assocs, experiment_key, component_key) + if component_key in ["aggregator", "ca"]: + # an experiment should only have one aggregator and/or one CA + assocs = get_last_component(assocs, experiment_key) + + if approval_status: + approval_status = 
approval_status.upper()
+        assocs = [
+            assoc for assoc in assocs if assoc["approval_status"] == approval_status
+        ]
+
+    return assocs
diff --git a/cli/medperf/commands/ca/associate.py b/cli/medperf/commands/ca/associate.py
new file mode 100644
index 000000000..9bf439b6b
--- /dev/null
+++ b/cli/medperf/commands/ca/associate.py
@@ -0,0 +1,34 @@
+from medperf import config
+from medperf.entities.ca import CA
+from medperf.entities.training_exp import TrainingExp
+from medperf.utils import approval_prompt
+from medperf.exceptions import InvalidArgumentError
+
+
+class AssociateCA:
+    @staticmethod
+    def run(training_exp_id: int, ca_uid: int, approved=False):
+        """Associates a CA with a training experiment
+
+        Args:
+            training_exp_id (int): UID of the training experiment to associate with
+            ca_uid (int): UID of the registered CA to associate
+        """
+        comms = config.comms
+        ui = config.ui
+        ca = CA.get(ca_uid)
+        if ca.id is None:
+            msg = "The provided ca is not registered."
+            raise InvalidArgumentError(msg)
+
+        training_exp = TrainingExp.get(training_exp_id)
+        msg = "Please confirm that you would like to associate"
+        msg += f" the ca {ca.name} with the training exp {training_exp.name}." 
+ msg += " [Y/n]" + + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating ca training association") + comms.associate_training_ca(ca.id, training_exp_id) + else: + ui.print("CA association operation cancelled.") diff --git a/cli/medperf/commands/ca/ca.py b/cli/medperf/commands/ca/ca.py new file mode 100644 index 000000000..167a072c9 --- /dev/null +++ b/cli/medperf/commands/ca/ca.py @@ -0,0 +1,102 @@ +from typing import Optional +from medperf.entities.ca import CA +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.ca.submit import SubmitCA +from medperf.commands.ca.associate import AssociateCA + +from medperf.commands.list import EntityList +from medperf.commands.view import EntityView + +app = typer.Typer() + + +@app.command("submit") +@clean_except +def submit( + name: str = typer.Option(..., "--name", "-n", help="Name of the ca"), + config_path: str = typer.Option( + ..., + "--config-path", + "-c", + help="Path to the configuration file (JSON) of the CA", + ), + ca_mlcube: int = typer.Option(..., "--ca-mlcube", help="CA MLCube UID"), + client_mlcube: int = typer.Option( + ..., + "--client-mlcube", + help="MLCube UID to be used by clients to get a cert", + ), + server_mlcube: int = typer.Option( + ..., + "--server-mlcube", + help="MLCube UID to be used by servers to get a cert", + ), +): + """Submits a ca""" + SubmitCA.run(name, config_path, ca_mlcube, client_mlcube, server_mlcube) + config.ui.print("✅ Done!") + + +@app.command("associate") +@clean_except +def associate( + ca_id: int = typer.Option(..., "--ca_id", "-a", help="UID of CA to associate with"), + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of training exp to associate with" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), +): + """Associates a CA with a training experiment.""" + AssociateCA.run(ca_id, training_exp_id, approved=approval) + 
config.ui.print("✅ Done!") + + +@app.command("ls") +@clean_except +def list( + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered CAs" + ), + mine: bool = typer.Option(False, "--mine", help="Get current-user CAs"), +): + """List CAs""" + EntityList.run( + CA, + fields=["UID", "Name", "Address", "Port"], + unregistered=unregistered, + mine_only=mine, + ) + + +@app.command("view") +@clean_except +def view( + entity_id: Optional[int] = typer.Argument(None, help="Benchmark ID"), + format: str = typer.Option( + "yaml", + "-f", + "--format", + help="Format to display contents. Available formats: [yaml, json]", + ), + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", + ), + mine: bool = typer.Option( + False, + "--mine", + help="Display current-user benchmarks if benchmark ID is not provided", + ), + output: str = typer.Option( + None, + "--output", + "-o", + help="Output file to store contents. 
If not provided, the output will be displayed", + ), +): + """Displays the information of one or more CAs""" + EntityView.run(entity_id, CA, format, unregistered, mine, output) diff --git a/cli/medperf/commands/ca/submit.py b/cli/medperf/commands/ca/submit.py new file mode 100644 index 000000000..cbb3f0348 --- /dev/null +++ b/cli/medperf/commands/ca/submit.py @@ -0,0 +1,65 @@ +import medperf.config as config +from medperf.entities.ca import CA +from medperf.utils import remove_path +from medperf.entities.cube import Cube + + +class SubmitCA: + @classmethod + def run( + cls, + name: str, + config_path: str, + ca_mlcube: int, + client_mlcube: int, + server_mlcube: int, + ): + """Submits a new ca to the medperf platform + Args: + name (str): ca name + config_path (dict): ca config + ca_mlcube (int): ca_mlcube mlcube uid + client_mlcube (int): client_mlcube mlcube uid + server_mlcube (int): server_mlcube mlcube uid + """ + ui = config.ui + submission = cls(name, config_path, ca_mlcube, client_mlcube, server_mlcube) + + with ui.interactive(): + ui.text = "Submitting CA to MedPerf" + submission.validate_ca_cubes() + updated_benchmark_body = submission.submit() + ui.print("Uploaded") + submission.write(updated_benchmark_body) + + def __init__( + self, + name: str, + config_path: str, + ca_mlcube: int, + client_mlcube: int, + server_mlcube: int, + ): + self.ui = config.ui + self.ca = CA( + name=name, + config=config_path, + ca_mlcube=ca_mlcube, + client_mlcube=client_mlcube, + server_mlcube=server_mlcube, + ) + config.tmp_paths.append(self.ca.path) + + def validate_ca_cubes(self): + Cube.get(self.ca.ca_mlcube) + Cube.get(self.ca.client_mlcube) + Cube.get(self.ca.server_mlcube) + + def submit(self): + updated_body = self.ca.upload() + return updated_body + + def write(self, updated_body): + remove_path(self.ca.path) + ca = CA(**updated_body) + ca.write() diff --git a/cli/medperf/commands/certificate/certificate.py b/cli/medperf/commands/certificate/certificate.py new file 
mode 100644 index 000000000..3a5e387fc --- /dev/null +++ b/cli/medperf/commands/certificate/certificate.py @@ -0,0 +1,38 @@ +import typer + +import medperf.config as config +from medperf.decorators import clean_except +from medperf.commands.certificate.client_certificate import GetUserCertificate +from medperf.commands.certificate.server_certificate import GetServerCertificate + +app = typer.Typer() + + +@app.command("get_client_certificate") +@clean_except +def get_client_certificate( + training_exp_id: int = typer.Option( + ..., + "--training_exp_id", + "-t", + help="UID of training exp which you intend to be a part of", + ), +): + """get a client certificate""" + GetUserCertificate.run(training_exp_id) + config.ui.print("✅ Done!") + + +@app.command("get_server_certificate") +@clean_except +def get_server_certificate( + training_exp_id: int = typer.Option( + ..., + "--training_exp_id", + "-t", + help="UID of training exp which you intend to be a part of", + ), +): + """get a server certificate""" + GetServerCertificate.run(training_exp_id) + config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/certificate/client_certificate.py b/cli/medperf/commands/certificate/client_certificate.py new file mode 100644 index 000000000..2103aca6c --- /dev/null +++ b/cli/medperf/commands/certificate/client_certificate.py @@ -0,0 +1,18 @@ +from medperf.entities.ca import CA +from medperf.account_management import get_medperf_user_data +from medperf.certificates import get_client_cert +from medperf.utils import get_pki_assets_path +import os + + +class GetUserCertificate: + @staticmethod + def run(training_exp_id: int): + """get user cert""" + ca = CA.from_experiment(training_exp_id) + email = get_medperf_user_data()["email"] + output_path = get_pki_assets_path(email, ca.name) + if os.path.exists(output_path) and os.listdir(output_path): + # TODO? 
+ raise ValueError("already") + get_client_cert(ca, email, output_path) diff --git a/cli/medperf/commands/certificate/server_certificate.py b/cli/medperf/commands/certificate/server_certificate.py new file mode 100644 index 000000000..68e45b2d6 --- /dev/null +++ b/cli/medperf/commands/certificate/server_certificate.py @@ -0,0 +1,19 @@ +from medperf.entities.ca import CA +from medperf.entities.aggregator import Aggregator +from medperf.certificates import get_server_cert +from medperf.utils import get_pki_assets_path +import os + + +class GetServerCertificate: + @staticmethod + def run(training_exp_id: int): + """get server cert""" + ca = CA.from_experiment(training_exp_id) + aggregator = Aggregator.from_experiment(training_exp_id) + address = aggregator.address + output_path = get_pki_assets_path(address, ca.name) + if os.path.exists(output_path) and os.listdir(output_path): + # TODO? + raise ValueError("already") + get_server_cert(ca, address, output_path) diff --git a/cli/medperf/commands/dataset/associate.py b/cli/medperf/commands/dataset/associate.py index 84359fd1d..a9de70b80 100644 --- a/cli/medperf/commands/dataset/associate.py +++ b/cli/medperf/commands/dataset/associate.py @@ -1,52 +1,31 @@ -from medperf import config -from medperf.entities.dataset import Dataset -from medperf.entities.benchmark import Benchmark -from medperf.utils import dict_pretty_print, approval_prompt -from medperf.commands.result.create import BenchmarkExecution +from medperf.commands.dataset.associate_benchmark import AssociateBenchmarkDataset +from medperf.commands.dataset.associate_training import AssociateTrainingDataset from medperf.exceptions import InvalidArgumentError class AssociateDataset: @staticmethod - def run(data_uid: int, benchmark_uid: int, approved=False, no_cache=False): - """Associates a registered dataset with a benchmark - - Args: - data_uid (int): UID of the registered dataset to associate - benchmark_uid (int): UID of the benchmark to associate with - """ - 
comms = config.comms - ui = config.ui - dset = Dataset.get(data_uid) - if dset.id is None: - msg = "The provided dataset is not registered." - raise InvalidArgumentError(msg) - - benchmark = Benchmark.get(benchmark_uid) - - if dset.data_preparation_mlcube != benchmark.data_preparation_mlcube: + def run( + data_uid: int, + benchmark_uid: int = None, + training_exp_uid: int = None, + approved=False, + no_cache=False, + ): + """Associates a dataset with a benchmark or a training exp""" + too_many_resources = benchmark_uid and training_exp_uid + no_resource = benchmark_uid is None and training_exp_uid is None + if no_resource or too_many_resources: raise InvalidArgumentError( - "The specified dataset wasn't prepared for this benchmark" + "Must provide either a benchmark or a training experiment" ) - - result = BenchmarkExecution.run( - benchmark_uid, - data_uid, - [benchmark.reference_model_mlcube], - no_cache=no_cache, - )[0] - ui.print("These are the results generated by the compatibility test. ") - ui.print("This will be sent along the association request.") - ui.print("They will not be part of the benchmark.") - dict_pretty_print(result.results) - - msg = "Please confirm that you would like to associate" - msg += f" the dataset {dset.name} with the benchmark {benchmark.name}." 
- msg += " [Y/n]" - approved = approved or approval_prompt(msg) - if approved: - ui.print("Generating dataset benchmark association") - metadata = {"test_result": result.results} - comms.associate_dset(dset.id, benchmark_uid, metadata) - else: - ui.print("Dataset association operation cancelled.") + if benchmark_uid: + AssociateBenchmarkDataset.run( + data_uid, benchmark_uid, approved=approved, no_cache=no_cache + ) + if training_exp_uid: + if no_cache: + raise InvalidArgumentError( + "no_cache argument is only valid when associating with a benchmark" + ) + AssociateTrainingDataset.run(data_uid, benchmark_uid, approved=approved) diff --git a/cli/medperf/commands/dataset/associate_benchmark.py b/cli/medperf/commands/dataset/associate_benchmark.py new file mode 100644 index 000000000..9b937c36d --- /dev/null +++ b/cli/medperf/commands/dataset/associate_benchmark.py @@ -0,0 +1,52 @@ +from medperf import config +from medperf.entities.dataset import Dataset +from medperf.entities.benchmark import Benchmark +from medperf.utils import dict_pretty_print, approval_prompt +from medperf.commands.result.create import BenchmarkExecution +from medperf.exceptions import InvalidArgumentError + + +class AssociateBenchmarkDataset: + @staticmethod + def run(data_uid: int, benchmark_uid: int, approved=False, no_cache=False): + """Associates a registered dataset with a benchmark + + Args: + data_uid (int): UID of the registered dataset to associate + benchmark_uid (int): UID of the benchmark to associate with + """ + comms = config.comms + ui = config.ui + dset = Dataset.get(data_uid) + if dset.id is None: + msg = "The provided dataset is not registered." 
+ raise InvalidArgumentError(msg) + + benchmark = Benchmark.get(benchmark_uid) + + if dset.data_preparation_mlcube != benchmark.data_preparation_mlcube: + raise InvalidArgumentError( + "The specified dataset wasn't prepared for this benchmark" + ) + + result = BenchmarkExecution.run( + benchmark_uid, + data_uid, + [benchmark.reference_model_mlcube], + no_cache=no_cache, + )[0] + ui.print("These are the results generated by the compatibility test. ") + ui.print("This will be sent along the association request.") + ui.print("They will not be part of the benchmark.") + dict_pretty_print(result.results) + + msg = "Please confirm that you would like to associate" + msg += f" the dataset {dset.name} with the benchmark {benchmark.name}." + msg += " [Y/n]" + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating dataset benchmark association") + metadata = {"test_result": result.results} + comms.associate_benchmark_dataset(dset.id, benchmark_uid, metadata) + else: + ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/dataset/associate_training.py b/cli/medperf/commands/dataset/associate_training.py new file mode 100644 index 000000000..7a3565089 --- /dev/null +++ b/cli/medperf/commands/dataset/associate_training.py @@ -0,0 +1,39 @@ +from medperf import config +from medperf.entities.dataset import Dataset +from medperf.entities.training_exp import TrainingExp +from medperf.utils import approval_prompt +from medperf.exceptions import InvalidArgumentError + + +class AssociateTrainingDataset: + @staticmethod + def run(data_uid: int, training_exp_uid: int, approved=False): + """Associates a dataset with a training experiment + + Args: + data_uid (int): UID of the registered dataset to associate + training_exp_uid (int): UID of the training experiment to associate with + """ + comms = config.comms + ui = config.ui + dset: Dataset = Dataset.get(data_uid) + if dset.id is None: + msg = "The provided dataset is not 
registered." + raise InvalidArgumentError(msg) + + training_exp = TrainingExp.get(training_exp_uid) + + if dset.data_preparation_mlcube != training_exp.data_preparation_mlcube: + raise InvalidArgumentError( + "The specified dataset wasn't prepared for this experiment" + ) + + msg = "Please confirm that you would like to associate" + msg += f" the dataset {dset.name} with the training experiment {training_exp.name}." + msg += " [Y/n]" + approved = approved or approval_prompt(msg) + if approved: + ui.print("Generating dataset training experiment association") + comms.associate_training_dataset(dset.id, training_exp_uid) + else: + ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index fc18022ac..d3243ec0d 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -124,20 +124,23 @@ def associate( ..., "--data_uid", "-d", help="Registered Dataset UID" ), benchmark_uid: int = typer.Option( - ..., "--benchmark_uid", "-b", help="Benchmark UID" + None, "--benchmark_uid", "-b", help="Benchmark UID" + ), + training_exp_uid: int = typer.Option( + None, "--training_exp_uid", "-t", help="Training experiment UID" ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), no_cache: bool = typer.Option( False, "--no-cache", - help="Execute the test even if results already exist", + help="Execute the benchmark association test even if results already exist", ), ): - """Associate a registered dataset with a specific benchmark. - The dataset and benchmark must share the same data preparation cube. 
- """ + """Associate a registered dataset with a specific benchmark or experiment.""" ui = config.ui - AssociateDataset.run(data_uid, benchmark_uid, approved=approval, no_cache=no_cache) + AssociateDataset.run( + data_uid, benchmark_uid, training_exp_uid, approved=approval, no_cache=no_cache + ) ui.print("✅ Done!") diff --git a/cli/medperf/commands/mlcube/associate.py b/cli/medperf/commands/mlcube/associate.py index 8307caade..9ee7317cc 100644 --- a/cli/medperf/commands/mlcube/associate.py +++ b/cli/medperf/commands/mlcube/associate.py @@ -40,6 +40,6 @@ def run( if approved: ui.print("Generating mlcube benchmark association") metadata = {"test_result": results} - comms.associate_cube(cube_uid, benchmark_uid, metadata) + comms.associate_benchmark_model(cube_uid, benchmark_uid, metadata) else: ui.print("MLCube association operation cancelled") diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 760dddc94..1b8622810 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -101,6 +101,8 @@ def validate(self): if dset_prep_cube != bmark_prep_cube: msg = "The provided dataset is not compatible with the specified benchmark." 
raise InvalidArgumentError(msg) + # TODO: there is no check if dataset is associated with the benchmark + # Note that if it is present, this will break dataset association logic def prepare_models(self): if self.models_input_file: diff --git a/cli/medperf/commands/training/approve.py b/cli/medperf/commands/training/approve.py deleted file mode 100644 index 8c42fd127..000000000 --- a/cli/medperf/commands/training/approve.py +++ /dev/null @@ -1,38 +0,0 @@ -from medperf import config -from medperf.exceptions import InvalidArgumentError - - -class TrainingAssociationApproval: - @staticmethod - def run( - training_exp_id: int, - approval_status, - data_uid: int = None, - aggregator: int = None, - ): - """Sets approval status for an association between a benchmark and a dataset or mlcube - - Args: - benchmark_uid (int): Benchmark UID. - approval_status (str): Desired approval status to set for the association. - comms (Comms): Instance of Comms interface. - ui (UI): Instance of UI interface. - dataset_uid (int, optional): Dataset UID. Defaults to None. - mlcube_uid (int, optional): MLCube UID. Defaults to None. 
- """ - comms = config.comms - too_many_resources = data_uid and aggregator - no_resource = data_uid is None and aggregator is None - if no_resource or too_many_resources: - raise InvalidArgumentError("Must provide either a dataset or aggregator") - - if data_uid: - # TODO: show CSR and ask for confirmation - comms.set_training_dataset_association_approval( - training_exp_id, data_uid, approval_status.value - ) - - if aggregator: - comms.set_aggregator_association_approval( - training_exp_id, aggregator, approval_status.value - ) diff --git a/cli/medperf/commands/training/associate.py b/cli/medperf/commands/training/associate.py deleted file mode 100644 index 0493d3004..000000000 --- a/cli/medperf/commands/training/associate.py +++ /dev/null @@ -1,45 +0,0 @@ -from medperf import config -from medperf.entities.dataset import Dataset -from medperf.entities.training_exp import TrainingExp -from medperf.utils import approval_prompt, generate_data_csr -from medperf.exceptions import InvalidArgumentError - - -class DatasetTrainingAssociation: - @staticmethod - def run(training_exp_id: int, data_uid: int, approved=False): - """Associates a registered dataset with a benchmark - - Args: - data_uid (int): UID of the registered dataset to associate - benchmark_uid (int): UID of the benchmark to associate with - """ - comms = config.comms - ui = config.ui - dset = Dataset.get(data_uid) - if dset.id is None: - msg = "The provided dataset is not registered." - raise InvalidArgumentError(msg) - - training_exp = TrainingExp.get(training_exp_id) - - if dset.data_preparation_mlcube != training_exp.data_preparation_mlcube: - raise InvalidArgumentError( - "The specified dataset wasn't prepared for this benchmark" - ) - - email = "" # TODO - csr, csr_hash = generate_data_csr(email, data_uid, training_exp_id) - msg = "Please confirm that you would like to associate" - msg += f" the dataset {dset.name} with the training exp {training_exp.name}." 
- msg += f" The certificate signing request hash is: {csr_hash}" - msg += " [Y/n]" - - approved = approved or approval_prompt(msg) - if approved: - ui.print("Generating dataset training association") - # TODO: delete keys if upload fails - # check if on failure, other (possible) request will overwrite key - comms.associate_training_dset(dset.id, training_exp_id, csr) - else: - ui.print("Dataset association operation cancelled.") diff --git a/cli/medperf/commands/training/close_event.py b/cli/medperf/commands/training/close_event.py new file mode 100644 index 000000000..2a922d97b --- /dev/null +++ b/cli/medperf/commands/training/close_event.py @@ -0,0 +1,59 @@ +import os +from medperf.entities.training_exp import TrainingExp +from medperf.entities.event import TrainingEvent +from medperf.utils import approval_prompt, dict_pretty_print +from medperf.exceptions import CleanExit, InvalidArgumentError +from medperf import config +import yaml + + +class CloseEvent: + """Used for both event cancellation (with custom report path) and for event closing + (with the expected report path generated by the aggregator)""" + + @classmethod + def run(cls, training_exp_id: int, report_path: str = None, approval: bool = False): + submission = cls(training_exp_id, report_path, approval) + submission.prepare() + submission.validate() + submission.read_report() + submission.submit() + submission.write() + + def __init__(self, training_exp_id: int, report_path: str, approval: bool): + self.training_exp_id = training_exp_id + self.approved = approval + self.report_path = report_path + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.event = TrainingEvent.from_experiment(self.training_exp_id) + self.report_path = self.report_path or self.event.report_path + + def validate(self): + if self.event.finished: + raise InvalidArgumentError("This experiment has already finished") + if not os.path.exists(self.report_path): + raise InvalidArgumentError(f"Report 
{self.report_path} does not exist.") + + def read_report(self): + with open(self.report_path) as f: + self.report = yaml.safe_load(f) + + def submit(self): + self.event.report = self.report + body = {"finished": True, "report": self.report} + dict_pretty_print(self.report) + msg = ( + f"You are about to close the event of training experiment {self.training_exp.name}." + " This will be the submitted report. Do you confirm? [Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + if self.approved: + config.comms.update_training_event(self.event.id, body) + return + raise CleanExit("Event closing cancelled") + + def write(self): + self.event.write() diff --git a/cli/medperf/commands/training/list_assocs.py b/cli/medperf/commands/training/list_assocs.py deleted file mode 100644 index 1b266cbf1..000000000 --- a/cli/medperf/commands/training/list_assocs.py +++ /dev/null @@ -1,41 +0,0 @@ -from tabulate import tabulate - -from medperf import config - - -class ListTrainingAssociations: - @staticmethod - def run(filter: str = None): - """Get training association requests""" - comms = config.comms - ui = config.ui - dset_assocs = comms.get_training_datasets_associations() - agg_assocs = comms.get_aggregators_associations() - - # Might be worth seeing if creating an association class that encapsulates - # most of the logic here is useful - assocs = dset_assocs + agg_assocs - if filter: - filter = filter.upper() - assocs = [assoc for assoc in assocs if assoc["approval_status"] == filter] - - assocs_info = [] - for assoc in assocs: - assoc_info = ( - assoc.get("dataset", None), - assoc.get("aggregator", None), - assoc["training_exp"], - assoc["initiated_by"], - assoc["approval_status"], - ) - assocs_info.append(assoc_info) - - headers = [ - "Dataset UID", - "Aggregator UID", - "TrainingExp UID", - "Initiated by", - "Status", - ] - tab = tabulate(assocs_info, headers=headers) - ui.print(tab) diff --git a/cli/medperf/commands/training/lock.py 
b/cli/medperf/commands/training/lock.py deleted file mode 100644 index eca813ed6..000000000 --- a/cli/medperf/commands/training/lock.py +++ /dev/null @@ -1,23 +0,0 @@ -from medperf import config -from medperf.entities.training_exp import TrainingExp - - -class LockTrainingExp: - @staticmethod - def run(training_exp_id: int): - """Sets approval status for an association between a benchmark and a dataset or mlcube - - Args: - benchmark_uid (int): Benchmark UID. - approval_status (str): Desired approval status to set for the association. - comms (Comms): Instance of Comms interface. - ui (UI): Instance of UI interface. - dataset_uid (int, optional): Dataset UID. Defaults to None. - mlcube_uid (int, optional): MLCube UID. Defaults to None. - """ - # TODO: this logic will be refactored when we merge entity edit PR - comms = config.comms - comms.set_experiment_as_operational(training_exp_id) - # update training experiment - training_exp = TrainingExp.get(training_exp_id) - training_exp.write() diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py index f47904d91..a7d48d26c 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -1,36 +1,33 @@ -import os from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.event import TrainingEvent from medperf.exceptions import InvalidArgumentError from medperf.entities.training_exp import TrainingExp from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube -from medperf.entities.aggregator import Aggregator -from medperf.utils import get_dataset_common_name +from medperf.utils import get_pki_assets_path, get_participant_label +from medperf.certificates import trust class TrainingExecution: @classmethod def run(cls, training_exp_id: int, data_uid: int): - """Sets approval status for an association between a benchmark and a 
dataset or mlcube + """Starts the aggregation server of a training experiment Args: - benchmark_uid (int): Benchmark UID. - approval_status (str): Desired approval status to set for the association. - comms (Comms): Instance of Comms interface. - ui (UI): Instance of UI interface. - dataset_uid (int, optional): Dataset UID. Defaults to None. - mlcube_uid (int, optional): MLCube UID. Defaults to None. + training_exp_id (int): Training experiment UID. """ execution = cls(training_exp_id, data_uid) execution.prepare() execution.validate() - execution.prepare_data_cert() - execution.prepare_network_config() - execution.prepare_cube() + execution.prepare_training_cube() + execution.prepare_plan() + execution.prepare_pki_assets() with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id, data_uid) -> None: + def __init__(self, training_exp_id: int, data_uid: int) -> None: self.training_exp_id = training_exp_id self.data_uid = data_uid self.ui = config.ui @@ -38,46 +35,29 @@ def __init__(self, training_exp_id, data_uid) -> None: def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Execution: {self.training_exp.name}") + self.event = TrainingEvent.from_experiment(self.training_exp_id) self.dataset = Dataset.get(self.data_uid) + self.user_email: str = get_medperf_user_data()["email"] def validate(self): if self.dataset.id is None: msg = "The provided dataset is not registered." raise InvalidArgumentError(msg) - if self.dataset.id not in self.training_exp.datasets: - msg = "The provided dataset is not associated." + if self.dataset.state != "OPERATION": + msg = "The provided dataset is not operational." raise InvalidArgumentError(msg) - if self.training_exp.state != "OPERATION": - msg = "The provided training exp is not operational." + if self.event.finished(): + msg = "The provided training experiment has to start a training event." 
raise InvalidArgumentError(msg) - def prepare_data_cert(self): - association = config.comms.get_training_dataset_association( - self.training_exp.id, self.dataset.id - ) - cert = association["certificate"] - cert_folder = os.path.join( - config.training_folder, - str(self.training_exp.id), - config.data_cert_folder, - str(self.dataset.id), - ) - os.makedirs(cert_folder, exist_ok=True) - cert_file = os.path.join(cert_folder, "cert.crt") - with open(cert_file, "w") as f: - f.write(cert) - - self.data_cert_path = cert_folder - - def prepare_network_config(self): - aggregator = config.comms.get_experiment_aggregator(self.training_exp.id) - aggregator = Aggregator.get(aggregator["id"]) - self.network_config_path = aggregator.network_config_path + if self.dataset.id not in self.training_exp.get_datasets_uids(): + msg = "The provided dataset is not associated." + raise InvalidArgumentError(msg) - def prepare_cube(self): - self.cube = self.__get_cube(self.training_exp.fl_mlcube, "training") + def prepare_training_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") def __get_cube(self, uid: int, name: str) -> Cube: self.ui.text = f"Retrieving {name} cube" @@ -86,22 +66,26 @@ def __get_cube(self, uid: int, name: str) -> Cube: self.ui.print(f"> {name} cube download complete") return cube - def run_experiment(self): - task = "train" - dataset_cn = get_dataset_common_name("", self.dataset.id, self.training_exp.id) - env_dict = {"COLLABORATOR_CN": dataset_cn} + def prepare_plan(self): + self.training_exp.prepare_plan() - # just for now create some output folders (TODO) - out_logs = os.path.join(self.training_exp.path, "data_logs", dataset_cn) - os.makedirs(out_logs, exist_ok=True) + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + def run_experiment(self): + participant_label = get_participant_label(self.user_email, 
self.dataset.id) + env_dict = {"COLLABORATOR_CN": participant_label} params = { "data_path": self.dataset.data_path, "labels_path": self.dataset.labels_path, - "node_cert_folder": self.data_cert_path, - "ca_cert_folder": self.training_exp.cert_path, - "network_config": self.network_config_path, - "output_logs": out_logs, + "node_cert_folder": self.dataset_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "output_logs": self.event.out_logs, } - self.ui.text = "Training" - self.cube.run(task=task, env_dict=env_dict, **params) + + self.ui.text = "Running Training" + self.cube.run(task="train", env_dict=env_dict, **params) diff --git a/cli/medperf/commands/training/set_plan.py b/cli/medperf/commands/training/set_plan.py new file mode 100644 index 000000000..f94123f05 --- /dev/null +++ b/cli/medperf/commands/training/set_plan.py @@ -0,0 +1,86 @@ +import medperf.config as config +from medperf.entities.aggregator import Aggregator +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.exceptions import CleanExit, InvalidArgumentError +from medperf.utils import approval_prompt, dict_pretty_print, generate_tmp_path +import os + +import yaml + + +class SetPlan: + @classmethod + def run( + cls, training_exp_id: int, training_config_path: str, approval: bool = False + ): + """Creates and submits the training plan + Args: + training_exp_id (int): training experiment + training_config_path (str): path to a training config file + approval (bool): skip approval + """ + planset = cls(training_exp_id, training_config_path, approval) + planset.validate() + planset.prepare() + planset.create_plan() + planset.update() + planset.write() + + def __init__(self, training_exp_id: int, training_config_path: str, approval: bool): + self.ui = config.ui + self.training_exp_id = training_exp_id + self.training_config_path = training_config_path + self.approval = approval + self.plan_out_path = 
generate_tmp_path() + + def validate(self): + if not os.path.exists(self.training_config_path): + raise InvalidArgumentError("Provided training config path does not exist") + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.aggregator = Aggregator.from_experiment(self.training_exp_id) + self.mlcube = self.__get_cube(self.training_exp.fl_mlcube, "FL") + self.aggregator.prepare_config() + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = f"Retrieving {name} cube" + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def create_plan(self): + """Auto-generates dataset UIDs for both input and output paths""" + params = { + "training_config_path": self.training_config_path, + "aggregator_config_path": self.aggregator.config_path, + "plan_path": self.plan_out_path, + } + self.mlcube.run("generate_plan", **params) + + def update(self): + with open(self.plan_out_path) as f: + plan = yaml.safe_load(f) + self.training_exp.plan = plan + body = {"plan": plan} + dict_pretty_print(body) + msg = ( + "This is the training plan that will be submitted and used by the participants." + " Do you confirm?[Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + if self.approved: + config.comms.update_training_exp(self.training_exp.id, body) + return + + raise CleanExit("Setting the training plan was cancelled") + + def write(self) -> str: + """Writes the registration into disk + Args: + filename (str, optional): name of the file. Defaults to config.reg_file. 
+ """ + self.training_exp.write() diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py new file mode 100644 index 000000000..13853bef6 --- /dev/null +++ b/cli/medperf/commands/training/start_event.py @@ -0,0 +1,58 @@ +from medperf.entities.training_exp import TrainingExp +from medperf.entities.event import TrainingEvent +from medperf.utils import approval_prompt, dict_pretty_print, get_participant_label +from medperf.exceptions import CleanExit, InvalidArgumentError + + +class StartEvent: + @classmethod + def run(cls, training_exp_id: int, approval: bool = False): + submission = cls(training_exp_id, approval) + submission.prepare() + submission.validate() + submission.create_participants_list() + submission.submit() + submission.write() + + def __init__(self, training_exp_id: int, approval): + self.training_exp_id = training_exp_id + self.approved = approval + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + + def validate(self): + if self.training_exp.approval_status != "APPROVED": + raise InvalidArgumentError("This experiment has not been approved yet") + + def create_participants_list(self): + datasets_with_users = TrainingExp.get_datasets_with_users(self.training_exp_id) + participants_list = {} + for dataset in datasets_with_users: + user_email = dataset["owner"]["email"] + data_id = dataset["id"] + participant_label = get_participant_label(user_email, data_id) + participant_common_name = user_email + participants_list[participant_label] = participant_common_name + self.participants_list = participants_list + + def submit(self): + dict_pretty_print(self.participants_list) + msg = ( + f"You are about to start an event for the training experiment {self.training_exp.name}." + " This is the list of participants (participant label, participant common name)" + " that will be able to participate in your training experiment. Do you confirm? 
[Y/n] " + ) + self.approved = self.approved or approval_prompt(msg) + + self.event = TrainingEvent( + training_exp=self.training_exp_id, participants=self.participants_list + ) + if self.approved: + updated_body = self.event.upload() + return updated_body + + raise CleanExit("Event creation cancelled") + + def write(self, updated_body): + self.event.write(updated_body) diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py index 5a662b450..7da6a96ba 100644 --- a/cli/medperf/commands/training/submit.py +++ b/cli/medperf/commands/training/submit.py @@ -1,5 +1,3 @@ -import os - import medperf.config as config from medperf.entities.training_exp import TrainingExp from medperf.entities.cube import Cube @@ -41,9 +39,7 @@ def __init__(self, training_exp_info: dict): def get_mlcube(self): mlcube_id = self.training_exp.fl_mlcube - cube = Cube.get(mlcube_id) - # TODO: do we want to download run files? - cube.download_run_files() + Cube.get(mlcube_id) def submit(self): updated_body = self.training_exp.upload() diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index ac1af744f..5efa629a4 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -1,6 +1,5 @@ from typing import Optional from medperf.entities.training_exp import TrainingExp -from medperf.enums import Status import typer import medperf.config as config @@ -8,10 +7,9 @@ from medperf.commands.training.submit import SubmitTrainingExp from medperf.commands.training.run import TrainingExecution -from medperf.commands.training.lock import LockTrainingExp -from medperf.commands.training.associate import DatasetTrainingAssociation -from medperf.commands.training.approve import TrainingAssociationApproval -from medperf.commands.training.list_assocs import ListTrainingAssociations +from medperf.commands.training.set_plan import SetPlan +from medperf.commands.training.start_event import 
StartEvent +from medperf.commands.training.close_event import CloseEvent from medperf.commands.list import EntityList from medperf.commands.view import EntityView @@ -26,12 +24,15 @@ def submit( ..., "--description", "-d", help="Description of the benchmark" ), docs_url: str = typer.Option("", "--docs-url", "-u", help="URL to documentation"), - prep_mlcube: int = typer.Option( - ..., "--prep-mlcube", "-p", help="prep MLCube UID" - ), + prep_mlcube: int = typer.Option(..., "--prep-mlcube", "-p", help="prep MLCube UID"), fl_mlcube: int = typer.Option( ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" ), + operational: bool = typer.Option( + False, + "--operational", + help="Submit the experiment as OPERATIONAL", + ), ): """Submits a new benchmark to the platform""" training_exp_info = { @@ -39,88 +40,85 @@ def submit( "description": description, "docs_url": docs_url, "fl_mlcube": fl_mlcube, - "demo_dataset_tarball_url": "link", # TODO later + "demo_dataset_tarball_url": "link", # TODO later "demo_dataset_tarball_hash": "hash", "demo_dataset_generated_uid": "uid", "data_preparation_mlcube": prep_mlcube, + "state": "OPERATION" if operational else "DEVELOPMENT", } SubmitTrainingExp.run(training_exp_info) config.ui.print("✅ Done!") -@app.command("lock") +@app.command("set_plan") @clean_except -def lock( +def set_plan( training_exp_id: int = typer.Option( ..., "--training_exp_id", "-t", help="UID of the desired benchmark" ), + training_config_path: str = typer.Option( + ..., "--config-path", "-c", help="config path" + ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - LockTrainingExp.run(training_exp_id) + SetPlan.run(training_exp_id, training_config_path, approval) config.ui.print("✅ Done!") -@app.command("run") +@app.command("start_event") @clean_except -def run( +def start_event( training_exp_id: int = typer.Option( ..., "--training_exp_id", 
"-t", help="UID of the desired benchmark" ), - data_uid: int = typer.Option( - ..., "--data_uid", "-d", help="Registered Dataset UID" - ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - TrainingExecution.run(training_exp_id, data_uid) + StartEvent.run(training_exp_id, approval) config.ui.print("✅ Done!") -@app.command("associate_dataset") +@app.command("close_event") @clean_except -def associate( +def close_event( training_exp_id: int = typer.Option( ..., "--training_exp_id", "-t", help="UID of the desired benchmark" ), - data_uid: int = typer.Option( - ..., "--data_uid", "-d", help="Registered Dataset UID" - ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - DatasetTrainingAssociation.run(training_exp_id, data_uid, approved=approval) + CloseEvent.run(training_exp_id, approval=approval) config.ui.print("✅ Done!") -@app.command("approve_association") +@app.command("cancel_event") @clean_except -def approve( +def cancel_event( training_exp_id: int = typer.Option( ..., "--training_exp_id", "-t", help="UID of the desired benchmark" ), - data_uid: int = typer.Option( - None, "--data_uid", "-d", help="Registered Dataset UID" - ), - aggregator: int = typer.Option( - None, "--aggregator", "-a", help="Registered Dataset UID" - ), + report_path: str = typer.Option(..., "--report-path", "-r", help="report path"), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - TrainingAssociationApproval.run( - training_exp_id, Status.APPROVED, data_uid, aggregator - ) + CloseEvent.run(training_exp_id, report_path=report_path, approval=approval) config.ui.print("✅ Done!") -@app.command("list_associations") +@app.command("run") @clean_except 
-def list(filter: Optional[str] = typer.Argument(None)): - """Display all training associations related to the current user. - - Args: - filter (str, optional): Filter training associations by approval status. - Defaults to displaying all user training associations. - """ - ListTrainingAssociations.run(filter) +def run( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + ..., "--data_uid", "-d", help="Registered Dataset UID" + ), +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + TrainingExecution.run(training_exp_id, data_uid) + config.ui.print("✅ Done!") @app.command("ls") diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 9840a0f57..8f61e006e 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -5,12 +5,7 @@ from medperf.enums import Status import medperf.config as config from medperf.comms.interface import Comms -from medperf.utils import ( - sanitize_json, - log_response_error, - format_errors_dict, - filter_latest_associations, -) +from medperf.utils import sanitize_json, log_response_error, format_errors_dict from medperf.exceptions import ( CommunicationError, CommunicationRetrievalError, @@ -81,6 +76,7 @@ def __get_list( page_size=config.default_page_size, offset=0, binary_reduction=False, + error_msg: str = "", ): """Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained. If num_elements is None, then iterates until all elements have been retrieved. 
@@ -110,16 +106,15 @@ def __get_list( if not binary_reduction: log_response_error(res) details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"there was an error retrieving the current list: {details}" - ) + raise CommunicationRetrievalError(f"{error_msg}: {details}") log_response_error(res, warn=True) details = format_errors_dict(res.json()) if page_size <= 1: - raise CommunicationRetrievalError( - f"Could not retrieve list. Minimum page size achieved without success: {details}" + logging.debug( + "Could not retrieve list. Minimum page size achieved without success" ) + raise CommunicationRetrievalError(f"{error_msg}: {details}") page_size = page_size // 2 continue else: @@ -133,154 +128,203 @@ def __get_list( return el_list[:num_elements] return el_list - def __set_approval_status(self, url: str, status: str) -> requests.Response: - """Sets the approval status of a resource + def __get(self, url: str, error_msg: str) -> dict: + """self.__auth_get with error handling""" + res = self.__auth_get(url) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError(f"{error_msg}: {details}") + return res.json() + + def __post(self, url: str, json: dict, error_msg: str) -> int: + """self.__auth_post with error handling""" + res = self.__auth_post(url, json=json) + if res.status_code != 201: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRetrievalError(f"{error_msg}: {details}") + return res.json() + + def __put(self, url: str, json: dict, error_msg: str): + """self.__auth_put with error handling""" + res = self.__auth_put(url, json=json) + if res.status_code != 200: + log_response_error(res) + details = format_errors_dict(res.json()) + raise CommunicationRequestError(f"{error_msg}: {details}") + + def get_current_user(self): + """Retrieve the currently-authenticated user information""" + url = f"{self.server_url}/me/" + error_msg = 
"Could not get current user" + return self.__get(url, error_msg) + + # get object + def get_benchmark(self, benchmark_uid: int) -> dict: + """Retrieves the benchmark specification file from the server Args: - url (str): URL to the resource to update - status (str): approval status to set + benchmark_uid (int): uid for the desired benchmark Returns: - requests.Response: Response object returned by the update + dict: benchmark specification """ - data = {"approval_status": status} - res = self.__auth_put(url, json=data) - return res + url = f"{self.server_url}/benchmarks/{benchmark_uid}" + error_msg = "Could not retrieve benchmark" + return self.__get(url, error_msg) - def __set_state(self, url: str, state: str) -> requests.Response: - """Sets the approval status of a resource + def get_cube_metadata(self, cube_uid: int) -> dict: + """Retrieves metadata about the specified cube Args: - url (str): URL to the resource to update - status (str): approval status to set + cube_uid (int): UID of the desired cube. Returns: - requests.Response: Response object returned by the update + dict: Dictionary containing url and hashes for the cube files """ - data = {"state": state} - res = self.__auth_put(url, json=data) - return res + url = f"{self.server_url}/mlcubes/{cube_uid}/" + error_msg = "Could not retrieve mlcube" + return self.__get(url, error_msg) - def get_current_user(self): - """Retrieve the currently-authenticated user information""" - res = self.__auth_get(f"{self.server_url}/me/") - return res.json() + def get_dataset(self, dset_uid: int) -> dict: + """Retrieves a specific dataset - def get_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks in the platform. + Args: + dset_uid (int): Dataset UID Returns: - List[dict]: all benchmarks information. 
+ dict: Dataset metadata """ - bmks = self.__get_list(f"{self.server_url}/benchmarks/") - return bmks + url = f"{self.server_url}/datasets/{dset_uid}/" + error_msg = "Could not retrieve dataset" + return self.__get(url, error_msg) - def get_benchmark(self, benchmark_uid: int) -> dict: - """Retrieves the benchmark specification file from the server + def get_result(self, result_uid: int) -> dict: + """Retrieves a specific result data Args: - benchmark_uid (int): uid for the desired benchmark + result_uid (int): Result UID Returns: - dict: benchmark specification + dict: Result metadata """ - res = self.__auth_get(f"{self.server_url}/benchmarks/{benchmark_uid}") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified benchmark doesn't exist: {details}" - ) - return res.json() + url = f"{self.server_url}/results/{result_uid}/" + error_msg = "Could not retrieve result" + return self.__get(url, error_msg) - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: - """Retrieves all the model associations of a benchmark. 
+ def get_training_exp(self, training_exp_id: int) -> dict: + """Retrieves the training_exp specification file from the server Args: - benchmark_uid (int): UID of the desired benchmark + training_exp_id (int): uid for the desired training_exp Returns: - list[int]: List of benchmark model associations + dict: training_exp specification """ - assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models") - return filter_latest_associations(assocs, "model_mlcube") + url = f"{self.server_url}/training/{training_exp_id}" + error_msg = "Could not retrieve training experiment" + return self.__get(url, error_msg) - def get_user_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks created by the user + def get_aggregator(self, aggregator_id: int) -> dict: + """Retrieves the aggregator specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark Returns: - List[dict]: Benchmarks data + dict: benchmark specification """ - bmks = self.__get_list(f"{self.server_url}/me/benchmarks/") - return bmks + url = f"{self.server_url}/aggregators/{aggregator_id}" + error_msg = "Could not retrieve aggregator" + return self.__get(url, error_msg) - def get_cubes(self) -> List[dict]: - """Retrieves all MLCubes in the platform + def get_ca(self, ca_id: int) -> dict: + """Retrieves the aggregator specification file from the server + + Args: + benchmark_uid (int): uid for the desired benchmark Returns: - List[dict]: List containing the data of all MLCubes + dict: benchmark specification """ - cubes = self.__get_list(f"{self.server_url}/mlcubes/") - return cubes + url = f"{self.server_url}/cas/{ca_id}" + error_msg = "Could not retrieve ca" + return self.__get(url, error_msg) - def get_cube_metadata(self, cube_uid: int) -> dict: - """Retrieves metadata about the specified cube + def get_training_event(self, event_id: int) -> dict: + """Retrieves the aggregator specification file from the server Args: - cube_uid (int): UID of the 
desired cube. + benchmark_uid (int): uid for the desired benchmark Returns: - dict: Dictionary containing url and hashes for the cube files + dict: benchmark specification """ - res = self.__auth_get(f"{self.server_url}/mlcubes/{cube_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified cube doesn't exist {details}" - ) - return res.json() + url = f"{self.server_url}/training/events/{event_id}" + error_msg = "Could not retrieve training event" + return self.__get(url, error_msg) - def get_user_cubes(self) -> List[dict]: - """Retrieves metadata from all cubes registered by the user + # get object of an object + def get_experiment_event(self, training_exp_id: int) -> dict: + """Retrieves the training experiment's event object from the server + + Args: + training_exp_id (int): uid for the training experiment Returns: - List[dict]: List of dictionaries containing the mlcubes registration information + dict: event specification """ - cubes = self.__get_list(f"{self.server_url}/me/mlcubes/") - return cubes + url = f"{self.server_url}/training/{training_exp_id}/event/" + error_msg = "Could not retrieve training experiment event" + return self.__get(url, error_msg) - def upload_benchmark(self, benchmark_dict: dict) -> int: - """Uploads a new benchmark to the server. 
+ def get_experiment_aggregator(self, training_exp_id: int) -> dict: + """Retrieves the training experiment's aggregator object from the server Args: - benchmark_dict (dict): benchmark_data to be uploaded + training_exp_id (int): uid for the training experiment Returns: - int: UID of newly created benchmark + dict: aggregator specification """ - res = self.__auth_post(f"{self.server_url}/benchmarks/", json=benchmark_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError(f"Could not upload benchmark: {details}") - return res.json() + url = f"{self.server_url}/training/{training_exp_id}/aggregator/" + error_msg = "Could not retrieve training experiment aggregator" + return self.__get(url, error_msg) - def upload_mlcube(self, mlcube_body: dict) -> int: - """Uploads an MLCube instance to the platform + def get_experiment_ca(self, training_exp_id: int) -> dict: + """Retrieves the training experiment's ca object from the server Args: - mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes + training_exp_id (int): uid for the training experiment Returns: - int: id of the created mlcube instance on the platform + dict: ca specification """ - res = self.__auth_post(f"{self.server_url}/mlcubes/", json=mlcube_body) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}") - return res.json() + url = f"{self.server_url}/training/{training_exp_id}/ca/" + error_msg = "Could not retrieve training experiment ca" + return self.__get(url, error_msg) + + # get list + def get_benchmarks(self) -> List[dict]: + """Retrieves all benchmarks in the platform. + + Returns: + List[dict]: all benchmarks information. 
+ """ + url = f"{self.server_url}/benchmarks/" + error_msg = "Could not retrieve benchmarks" + return self.__get_list(url, error_msg=error_msg) + + def get_cubes(self) -> List[dict]: + """Retrieves all MLCubes in the platform + + Returns: + List[dict]: List containing the data of all MLCubes + """ + url = f"{self.server_url}/mlcubes/" + error_msg = "Could not retrieve mlcubes" + return self.__get_list(url, error_msg=error_msg) def get_datasets(self) -> List[dict]: """Retrieves all datasets in the platform @@ -288,82 +332,90 @@ def get_datasets(self) -> List[dict]: Returns: List[dict]: List of data from all datasets """ - dsets = self.__get_list(f"{self.server_url}/datasets/") - return dsets + url = f"{self.server_url}/datasets/" + error_msg = "Could not retrieve datasets" + return self.__get_list(url, error_msg=error_msg) - def get_dataset(self, dset_uid: int) -> dict: - """Retrieves a specific dataset + def get_results(self) -> List[dict]: + """Retrieves all results - Args: - dset_uid (int): Dataset UID + Returns: + List[dict]: List of results + """ + url = f"{self.server_url}/results/" + error_msg = "Could not retrieve results" + return self.__get_list(url, error_msg=error_msg) + + def get_training_exps(self) -> List[dict]: + """Retrieves all training_exps Returns: - dict: Dataset metadata + List[dict]: List of training_exps """ - res = self.__auth_get(f"{self.server_url}/datasets/{dset_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"Could not retrieve the specified dataset from server: {details}" - ) - return res.json() + url = f"{self.server_url}/training/" + error_msg = "Could not retrieve training experiments" + return self.__get_list(url, error_msg=error_msg) - def get_user_datasets(self) -> dict: - """Retrieves all datasets registered by the user + def get_aggregators(self) -> List[dict]: + """Retrieves all aggregators Returns: - dict: dictionary with the 
contents of each dataset registration query + List[dict]: List of aggregators """ - dsets = self.__get_list(f"{self.server_url}/me/datasets/") - return dsets + url = f"{self.server_url}/aggregators/" + error_msg = "Could not retrieve aggregators" + return self.__get_list(url, error_msg=error_msg) - def upload_dataset(self, reg_dict: dict) -> int: - """Uploads registration data to the server, under the sha name of the file. + def get_cas(self) -> List[dict]: + """Retrieves all training events - Args: - reg_dict (dict): Dictionary containing registration information. + Returns: + List[dict]: List of training events + """ + url = f"{self.server_url}/training/events/" + error_msg = "Could not retrieve training events" + return self.__get_list(url, error_msg=error_msg) + + def get_training_events(self) -> List[dict]: + """Retrieves all training events Returns: - int: id of the created dataset registration. + List[dict]: List of training events """ - res = self.__auth_post(f"{self.server_url}/datasets/", json=reg_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError(f"Could not upload the dataset: {details}") - return res.json() + url = f"{self.server_url}/training/events/" + error_msg = "Could not retrieve training events" + return self.__get_list(url, error_msg=error_msg) - def get_results(self) -> List[dict]: - """Retrieves all results + # get user list + def get_user_cubes(self) -> List[dict]: + """Retrieves metadata from all cubes registered by the user Returns: - List[dict]: List of results + List[dict]: List of dictionaries containing the mlcubes registration information """ - res = self.__get_list(f"{self.server_url}/results") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError(f"Could not retrieve results: {details}") - return res.json() + url = f"{self.server_url}/me/mlcubes/" + error_msg = "Could 
not retrieve user mlcubes" + return self.__get_list(url, error_msg=error_msg) - def get_result(self, result_uid: int) -> dict: - """Retrieves a specific result data + def get_user_datasets(self) -> dict: + """Retrieves all datasets registered by the user - Args: - result_uid (int): Result UID + Returns: + dict: dictionary with the contents of each dataset registration query + """ + url = f"{self.server_url}/me/datasets/" + error_msg = "Could not retrieve user datasets" + return self.__get_list(url, error_msg=error_msg) + + def get_user_benchmarks(self) -> List[dict]: + """Retrieves all benchmarks created by the user Returns: - dict: Result metadata + List[dict]: Benchmarks data """ - res = self.__auth_get(f"{self.server_url}/results/{result_uid}/") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"Could not retrieve the specified result: {details}" - ) - return res.json() + url = f"{self.server_url}/me/benchmarks/" + error_msg = "Could not retrieve user benchmarks" + return self.__get_list(url, error_msg=error_msg) def get_user_results(self) -> dict: """Retrieves all results registered by the user @@ -371,275 +423,195 @@ def get_user_results(self) -> dict: Returns: dict: dictionary with the contents of each result registration query """ - results = self.__get_list(f"{self.server_url}/me/results/") - return results + url = f"{self.server_url}/me/results/" + error_msg = "Could not retrieve user results" + return self.__get_list(url, error_msg=error_msg) - def get_benchmark_results(self, benchmark_id: int) -> dict: - """Retrieves all results for a given benchmark + def get_user_training_exps(self) -> dict: + """Retrieves all training_exps registered by the user - Args: - benchmark_id (int): benchmark ID to retrieve results from + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/training/" + error_msg = "Could not 
retrieve user training experiments" + return self.__get_list(url, error_msg=error_msg) + + def get_user_aggregators(self) -> dict: + """Retrieves all aggregators registered by the user Returns: - dict: dictionary with the contents of each result in the specified benchmark + dict: dictionary with the contents of each result registration query """ - results = self.__get_list( - f"{self.server_url}/benchmarks/{benchmark_id}/results" - ) - return results + url = f"{self.server_url}/me/aggregators/" + error_msg = "Could not retrieve user aggregators" + return self.__get_list(url, error_msg=error_msg) - def upload_result(self, results_dict: dict) -> int: - """Uploads result to the server. + def get_user_cas(self) -> dict: + """Retrieves all cas registered by the user - Args: - results_dict (dict): Dictionary containing results information. + Returns: + dict: dictionary with the contents of each result registration query + """ + url = f"{self.server_url}/me/cas/" + error_msg = "Could not retrieve user cas" + return self.__get_list(url, error_msg=error_msg) + + def get_user_training_events(self) -> dict: + """Retrieves all training events registered by the user Returns: - int: id of the generated results entry + dict: dictionary with the contents of each result registration query """ - res = self.__auth_post(f"{self.server_url}/results/", json=results_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError(f"Could not upload the results: {details}") - return res.json() + url = f"{self.server_url}/me/training/events/" + error_msg = "Could not retrieve user training events" + return self.__get_list(url, error_msg=error_msg) - def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}): - """Create a Dataset Benchmark association + # get user associations list + def get_user_benchmarks_datasets_associations(self) -> List[dict]: + """Get all dataset associations related to 
the current user - Args: - data_uid (int): Registered dataset UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. + Returns: + List[dict]: List containing all associations information """ - data = { - "dataset": data_uid, - "benchmark": benchmark_uid, - "approval_status": Status.PENDING.value, - "metadata": metadata, - } - res = self.__auth_post(f"{self.server_url}/datasets/benchmarks/", json=data) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not associate dataset to benchmark: {details}" - ) + url = f"{self.server_url}/me/datasets/associations/" + error_msg = "Could not retrieve user datasets benchmark associations" + return self.__get_list(url, error_msg=error_msg) - def associate_cube(self, cube_uid: int, benchmark_uid: int, metadata: dict = {}): - """Create an MLCube-Benchmark association + def get_user_benchmarks_models_associations(self) -> List[dict]: + """Get all cube associations related to the current user - Args: - cube_uid (int): MLCube UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. 
+ Returns: + List[dict]: List containing all associations information """ - data = { - "approval_status": Status.PENDING.value, - "model_mlcube": cube_uid, - "benchmark": benchmark_uid, - "metadata": metadata, - } - res = self.__auth_post(f"{self.server_url}/mlcubes/benchmarks/", json=data) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not associate mlcube to benchmark: {details}" - ) + url = f"{self.server_url}/me/mlcubes/associations/" + error_msg = "Could not retrieve user mlcubes benchmark associations" + return self.__get_list(url, error_msg=error_msg) - def set_dataset_association_approval( - self, benchmark_uid: int, dataset_uid: int, status: str - ): - """Approves a dataset association + def get_user_training_datasets_associations(self) -> List[dict]: + """Get all training dataset associations related to the current user - Args: - dataset_uid (int): Dataset UID - benchmark_uid (int): Benchmark UID - status (str): Approval status to set for the association - """ - url = f"{self.server_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not approve association between dataset {dataset_uid} and benchmark {benchmark_uid}: {details}" - ) - - def set_mlcube_association_approval( - self, benchmark_uid: int, mlcube_uid: int, status: str - ): - """Approves an mlcube association - - Args: - mlcube_uid (int): Dataset UID - benchmark_uid (int): Benchmark UID - status (str): Approval status to set for the association + Returns: + List[dict]: List containing all associations information """ - url = f"{self.server_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - log_response_error(res) - details = 
format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}" - ) + url = f"{self.server_url}/me/datasets/training_associations/" + error_msg = "Could not retrieve user datasets training associations" + return self.__get_list(url, error_msg=error_msg) - def get_datasets_associations(self) -> List[dict]: - """Get all dataset associations related to the current user + def get_user_training_aggregators_associations(self) -> List[dict]: + """Get all aggregator associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/") - return filter_latest_associations(assocs, "dataset") + url = f"{self.server_url}/me/aggregators/training_associations/" + error_msg = "Could not retrieve user aggregators training associations" + return self.__get_list(url, error_msg=error_msg) - def get_cubes_associations(self) -> List[dict]: - """Get all cube associations related to the current user + def get_user_training_cas_associations(self) -> List[dict]: + """Get all ca associations related to the current user Returns: List[dict]: List containing all associations information """ - assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/") - return filter_latest_associations(assocs, "model_mlcube") + url = f"{self.server_url}/me/cas/training_associations/" + error_msg = "Could not retrieve user cas training associations" + return self.__get_list(url, error_msg=error_msg) - def set_mlcube_association_priority( - self, benchmark_uid: int, mlcube_uid: int, priority: int - ): - """Sets the priority of an mlcube-benchmark association + # upload + def upload_benchmark(self, benchmark_dict: dict) -> int: + """Uploads a new benchmark to the server. 
Args: - mlcube_uid (int): MLCube UID - benchmark_uid (int): Benchmark UID - priority (int): priority value to set for the association - """ - url = f"{self.server_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" - data = {"priority": priority} - res = self.__auth_put(url, json=data) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not set the priority of mlcube {mlcube_uid} within the benchmark {benchmark_uid}: {details}" - ) + benchmark_dict (dict): benchmark_data to be uploaded - def update_dataset(self, dataset_id: int, data: dict): - url = f"{self.server_url}/datasets/{dataset_id}/" - res = self.__auth_put(url, json=data) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError(f"Could not update dataset: {details}") - return res.json() + Returns: + int: UID of newly created benchmark + """ + url = f"{self.server_url}/benchmarks/" + error_msg = "could not upload benchmark" + return self.__post(url, json=benchmark_dict, error_msg=error_msg) - def get_mlcube_datasets(self, mlcube_id: int) -> dict: - """Retrieves all datasets that have the specified mlcube as the prep mlcube + def upload_mlcube(self, mlcube_body: dict) -> int: + """Uploads an MLCube instance to the platform Args: - mlcube_id (int): mlcube ID to retrieve datasets from + mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes Returns: - dict: dictionary with the contents of each dataset + int: id of the created mlcube instance on the platform """ + url = f"{self.server_url}/mlcubes/" + error_msg = "could not upload mlcube" + return self.__post(url, json=mlcube_body, error_msg=error_msg) - datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/") - return datasets - - def upload_training_exp(self, training_exp_dict: dict) -> int: - """Uploads a new training_exp to the server. 
+ def upload_dataset(self, reg_dict: dict) -> int: + """Uploads registration data to the server, under the sha name of the file. Args: - benchmark_dict (dict): benchmark_data to be uploaded + reg_dict (dict): Dictionary containing registration information. Returns: - int: UID of newly created benchmark + int: id of the created dataset registration. """ - res = self.__auth_post(f"{self.server_url}/training/", json=training_exp_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"Could not upload training exp: {details}" - ) - return res.json() + url = f"{self.server_url}/datasets/" + error_msg = "could not upload dataset" + return self.__post(url, json=reg_dict, error_msg=error_msg) - def get_training_exp(self, training_exp_id: int) -> dict: - """Retrieves the training_exp specification file from the server + def upload_result(self, results_dict: dict) -> int: + """Uploads result to the server. Args: - benchmark_uid (int): uid for the desired benchmark + results_dict (dict): Dictionary containing results information. Returns: - dict: benchmark specification + dicr: generated results entry """ - res = self.__auth_get(f"{self.server_url}/training/{training_exp_id}") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified training_exp doesn't exist: {details}" - ) - return res.json() + url = f"{self.server_url}/results/" + error_msg = "could not upload result" + return self.__post(url, json=results_dict, error_msg=error_msg) - def get_experiment_datasets(self, training_exp_id: int) -> dict: - """Retrieves all approved datasets for a given training_exp + def upload_training_exp(self, training_exp_dict: dict) -> int: + """Uploads a new training_exp to the server. 
Args: - benchmark_id (int): benchmark ID to retrieve results from + training_exp_dict (dict): training_exp to be uploaded Returns: - dict: dictionary with the contents of each result in the specified benchmark + dict: newly created training_exp """ - results = self.__get_list( - f"{self.server_url}/training/{training_exp_id}/datasets" - ) - results = [dataset["id"] for dataset in results] + url = f"{self.server_url}/training/" + error_msg = "could not upload training experiment" + return self.__post(url, json=training_exp_dict, error_msg=error_msg) - return results - - def get_experiment_aggregator(self, training_exp_id: int) -> dict: - """Retrieves the experiment aggregator + def upload_aggregator(self, aggregator_dict: dict) -> int: + """Uploads a new aggregator to the server. Args: - benchmark_id (int): benchmark ID to retrieve results from + benchmark_dict (dict): benchmark_data to be uploaded Returns: - dict: dictionary with the contents of each result in the specified benchmark + int: UID of newly created benchmark """ + url = f"{self.server_url}/aggregators/" + error_msg = "could not upload aggregator" + return self.__post(url, json=aggregator_dict, error_msg=error_msg) - res = self.__auth_get( - f"{self.server_url}/training/{training_exp_id}/aggregator" - ) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"There was a problem when retrieving the aggregator: {details}" - ) - return res.json() - - def set_experiment_as_operational(self, training_exp_id: int) -> dict: - """lock experiment (set as operational) + def upload_ca(self, ca_dict: dict) -> int: + """Uploads a new ca to the server. 
Args: - benchmark_id (int): benchmark ID to retrieve results from + benchmark_dict (dict): benchmark_data to be uploaded Returns: - dict: dictionary with the contents of each result in the specified benchmark + int: UID of newly created benchmark """ + url = f"{self.server_url}/cas/" + error_msg = "could not upload ca" + return self.__post(url, json=ca_dict, error_msg=error_msg) - url = f"{self.server_url}/training/{training_exp_id}/" - res = self.__set_state(url, "OPERATION") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not set operational state for experiment {training_exp_id}: {details}" - ) - - def upload_aggregator(self, aggregator_dict: dict) -> int: - """Uploads a new aggregator to the server. + def upload_training_event(self, trainnig_event_dict: dict) -> int: + """Uploads a new training event to the server. Args: benchmark_dict (dict): benchmark_data to be uploaded @@ -647,213 +619,249 @@ def upload_aggregator(self, aggregator_dict: dict) -> int: Returns: int: UID of newly created benchmark """ - res = self.__auth_post(f"{self.server_url}/aggregators/", json=aggregator_dict) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError(f"Could not upload aggregator: {details}") - return res.json() + url = f"{self.server_url}/training/events/" + error_msg = "could not upload training event" + return self.__post(url, json=trainnig_event_dict, error_msg=error_msg) - def get_aggregator(self, aggregator_id: int) -> dict: - """Retrieves the aggregator specification file from the server + # Association creation + def associate_benchmark_dataset( + self, data_uid: int, benchmark_uid: int, metadata: dict = {} + ): + """Create a Dataset Benchmark association Args: - benchmark_uid (int): uid for the desired benchmark + data_uid (int): Registered dataset UID + benchmark_uid (int): Benchmark UID + 
metadata (dict, optional): Additional metadata. Defaults to {}. + """ + url = f"{self.server_url}/datasets/benchmarks/" + data = { + "dataset": data_uid, + "benchmark": benchmark_uid, + "approval_status": Status.PENDING.value, + "metadata": metadata, + } + error_msg = "Could not associate dataset to benchmark" + return self.__post(url, json=data, error_msg=error_msg) - Returns: - dict: benchmark specification + def associate_benchmark_model( + self, cube_uid: int, benchmark_uid: int, metadata: dict = {} + ): + """Create an MLCube-Benchmark association + + Args: + cube_uid (int): MLCube UID + benchmark_uid (int): Benchmark UID + metadata (dict, optional): Additional metadata. Defaults to {}. """ - res = self.__auth_get(f"{self.server_url}/aggregators/{aggregator_id}") - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"the specified aggregator doesn't exist: {details}" - ) - return res.json() + url = f"{self.server_url}/mlcubes/benchmarks/" + data = { + "approval_status": Status.PENDING.value, + "model_mlcube": cube_uid, + "benchmark": benchmark_uid, + "metadata": metadata, + } + error_msg = "Could not associate mlcube to benchmark" + return self.__post(url, json=data, error_msg=error_msg) - def associate_aggregator(self, aggregator_id: int, training_exp_id: int, csr: str): - """Create a aggregator experiment association + def associate_training_dataset(self, data_uid: int, training_exp_id: int): + """Create a Dataset experiment association Args: data_uid (int): Registered dataset UID benchmark_uid (int): Benchmark UID metadata (dict, optional): Additional metadata. Defaults to {}. 
""" + url = f"{self.server_url}/datasets/training/" + data = { + "dataset": data_uid, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + } + error_msg = "Could not associate dataset to training_exp" + return self.__post(url, json=data, error_msg=error_msg) + + def associate_training_aggregator(self, aggregator_id: int, training_exp_id: int): + """Create a aggregator experiment association + + Args: + aggregator_id (int): Registered aggregator UID + training_exp_id (int): training experiment UID + """ + url = f"{self.server_url}/aggregators/training/" data = { "aggregator": aggregator_id, "training_exp": training_exp_id, "approval_status": Status.PENDING.value, - "signing_request": csr, } - res = self.__auth_post( - f"{self.server_url}/aggregators/training_experiments/", json=data - ) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not associate aggregator to training_exp: {details}" - ) + error_msg = "Could not associate aggregator to training_exp" + return self.__post(url, json=data, error_msg=error_msg) + + def associate_training_ca(self, ca_id: int, training_exp_id: int): + """Create a ca experiment association + + Args: + ca_id (int): Registered ca UID + training_exp_id (int): training experiment UID + """ + url = f"{self.server_url}/cas/training/" + data = { + "aggregator": ca_id, + "training_exp": training_exp_id, + "approval_status": Status.PENDING.value, + } + error_msg = "Could not associate ca to training_exp" + return self.__post(url, json=data, error_msg=error_msg) - def set_aggregator_association_approval( - self, training_exp_id: int, aggregator_id: int, status: str + # updates associations + def update_benchmark_dataset_association( + self, benchmark_uid: int, dataset_uid: int, data: str ): - """Approves a aggregator association + """Approves a dataset association Args: dataset_uid (int): Dataset UID benchmark_uid (int): 
Benchmark UID status (str): Approval status to set for the association """ - url = f"{self.server_url}/aggregators/{aggregator_id}/training_experiments/{training_exp_id}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - "Could not approve association between aggregator" - f"{aggregator_id} and training_exp {training_exp_id}: {details}" - ) + url = f"{self.server_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" + error_msg = f"Could not update association: dataset {dataset_uid}, benchmark {benchmark_uid}" + self.__put(url, json=data, error_msg=error_msg) - def get_aggregator_association( - self, training_exp_id: int, aggregator_id: int - ) -> dict: - """Retrieves the aggregator association specification file from the server + def update_benchmark_model_association( + self, benchmark_uid: int, mlcube_uid: int, data: dict + ): + """Approves an mlcube association Args: - benchmark_uid (int): uid for the desired benchmark - - Returns: - dict: benchmark specification + mlcube_uid (int): Dataset UID + benchmark_uid (int): Benchmark UID + status (str): Approval status to set for the association """ - url = f"{self.server_url}/aggregators/{aggregator_id}/training_experiments/{training_exp_id}/" - res = self.__auth_get(url) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"There was a problem when retrieving the association: {details}" - ) - return res.json() + url = f"{self.server_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" + error_msg = ( + f"Could update association: mlcube {mlcube_uid}, benchmark {benchmark_uid}" + ) + self.__put(url, json=data, error_msg=error_msg) - def associate_training_dset(self, data_uid: int, training_exp_id: int, csr: str): - """Create a Dataset experiment association + def 
update_training_aggregator_association( + self, training_exp_id: int, aggregator_id: int, data: dict + ): + """Approves a aggregator association Args: - data_uid (int): Registered dataset UID + dataset_uid (int): Dataset UID benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. + status (str): Approval status to set for the association """ - data = { - "dataset": data_uid, - "training_exp": training_exp_id, - "approval_status": Status.PENDING.value, - "signing_request": csr, - } - res = self.__auth_post( - f"{self.server_url}/datasets/training_experiments/", json=data + url = ( + f"{self.server_url}/aggregators/{aggregator_id}/training/{training_exp_id}/" ) - if res.status_code != 201: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - f"Could not associate dataset to training_exp: {details}" - ) + error_msg = ( + "Could not update association: aggregator" + f" {aggregator_id}, training_exp {training_exp_id}" + ) + self.__put(url, json=data, error_msg=error_msg) - def set_training_dataset_association_approval( - self, training_exp_id: int, dataset_uid: int, status: str + def update_training_dataset_association( + self, training_exp_id: int, dataset_uid: int, data: dict ): - """Approves a trainining dataset association + """Approves a training dataset association Args: dataset_uid (int): Dataset UID benchmark_uid (int): Benchmark UID status (str): Approval status to set for the association """ - url = f"{self.server_url}/datasets/{dataset_uid}/training_experiments/{training_exp_id}/" - res = self.__set_approval_status(url, status) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRequestError( - "Could not approve association between dataset" - f"{dataset_uid} and training_exp {training_exp_id}: {details}" - ) + url = f"{self.server_url}/datasets/{dataset_uid}/training/{training_exp_id}/" + 
error_msg = ( + "Could not approve association: dataset" + f"{dataset_uid}, training_exp {training_exp_id}" + ) + self.__put(url, json=data, error_msg=error_msg) - def get_training_dataset_association( - self, training_exp_id: int, dataset_uid: int - ) -> dict: - """Retrieves the training dataset association specification file from the server + def update_training_ca_association( + self, training_exp_id: int, ca_uid: int, data: dict + ): + """Approves a training ca association Args: - benchmark_uid (int): uid for the desired benchmark - - Returns: - dict: benchmark specification + dataset_uid (int): Dataset UID + benchmark_uid (int): Benchmark UID + status (str): Approval status to set for the association """ - url = f"{self.server_url}/datasets/{dataset_uid}/training_experiments/{training_exp_id}/" - res = self.__auth_get(url) - if res.status_code != 200: - log_response_error(res) - details = format_errors_dict(res.json()) - raise CommunicationRetrievalError( - f"There was a problem when retrieving the association: {details}" - ) - return res.json() + url = f"{self.server_url}/cas/{ca_uid}/training/{training_exp_id}/" + error_msg = ( + "Could not update association: ca" + f"{ca_uid}, training_exp {training_exp_id}" + ) + self.__put(url, json=data, error_msg=error_msg) - def get_aggregators(self) -> List[dict]: - """Retrieves all aggregators + # update objects + def update_dataset(self, dataset_id: int, data: dict): + url = f"{self.server_url}/datasets/{dataset_id}/" + error_msg = "Could not update dataset" + return self.__put(url, json=data, error_msg=error_msg) - Returns: - List[dict]: List of aggregators - """ - aggregators = self.__get_list(f"{self.server_url}/aggregators") - return aggregators + def update_training_exp(self, training_exp_id: int, data: dict): + url = f"{self.server_url}/training/{training_exp_id}/" + error_msg = "Could not update training experiment" + return self.__put(url, json=data, error_msg=error_msg) - def get_user_aggregators(self) -> 
dict: - """Retrieves all aggregators registered by the user + def update_training_event(self, training_event_id: int, data: dict): + url = f"{self.server_url}/training/events/{training_event_id}/" + error_msg = "Could not update training event" + return self.__put(url, json=data, error_msg=error_msg) - Returns: - dict: dictionary with the contents of each result registration query - """ - aggregators = self.__get_list(f"{self.server_url}/me/aggregators/") - return aggregators + # misc + def get_benchmark_results(self, benchmark_id: int) -> dict: + """Retrieves all results for a given benchmark - def get_training_exps(self) -> List[dict]: - """Retrieves all training_exps + Args: + benchmark_id (int): benchmark ID to retrieve results from Returns: - List[dict]: List of training_exps + dict: dictionary with the contents of each result in the specified benchmark """ - training_exps = self.__get_list(f"{self.server_url}/training") - return training_exps + url = f"{self.server_url}/benchmarks/{benchmark_id}/results/" + error_msg = "Could not get benchmark results" + return self.__get_list(url, error_msg=error_msg) - def get_user_training_exps(self) -> dict: - """Retrieves all training_exps registered by the user + def get_mlcube_datasets(self, mlcube_id: int) -> dict: + """Retrieves all datasets that have the specified mlcube as the prep mlcube + + Args: + mlcube_id (int): mlcube ID to retrieve datasets from Returns: - dict: dictionary with the contents of each result registration query + dict: dictionary with the contents of each dataset """ - training_exps = self.__get_list(f"{self.server_url}/me/training/") - return training_exps + url = f"{self.server_url}/mlcubes/{mlcube_id}/datasets/" + error_msg = "Could not get mlcube datasets" + return self.__get_list(url, error_msg=error_msg) - def get_training_datasets_associations(self) -> List[dict]: - """Get all training dataset associations related to the current user + def get_training_datasets_associations(self, 
training_exp_id: int) -> dict: + """Retrieves all approved datasets for a given training_exp + + Args: + benchmark_id (int): benchmark ID to retrieve results from Returns: - List[dict]: List containing all associations information + dict: dictionary with the contents of each result in the specified benchmark """ - assocs = self.__get_list( - f"{self.server_url}/me/datasets/training_associations/" - ) - return assocs + url = f"{self.server_url}/training/{training_exp_id}/datasets" + error_msg = "Could not get training experiment datasets associations" + return self.__get_list(url, error_msg=error_msg) - def get_aggregators_associations(self) -> List[dict]: - """Get all aggregator associations related to the current user + def get_benchmark_models_associations(self, benchmark_uid: int) -> List[int]: + """Retrieves all the model associations of a benchmark. + + Args: + benchmark_uid (int): UID of the desired benchmark Returns: - List[dict]: List containing all associations information + list[int]: List of benchmark model associations """ - assocs = self.__get_list(f"{self.server_url}/me/aggregators/associations/") - return assocs + url = f"{self.server_url}/benchmarks/{benchmark_uid}/models" + error_msg = "Could not get benchmark models associations" + return self.__get_list(url, error_msg=error_msg) diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 08c6656bc..b870959a2 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -48,6 +48,7 @@ auth_jwks_file = str(config_storage / ".jwks") creds_folder = str(config_storage / ".tokens") tokens_db = str(config_storage / ".tokens_db") +pki_assets = str(config_storage / ".pki_assets") images_folder = ".images" trash_folder = ".trash" @@ -144,6 +145,7 @@ benchmarks_filename = "benchmark.yaml" test_report_file = "test_report.yaml" reg_file = "registration-info.yaml" +agg_file = "agg-info.yaml" cube_metadata_filename = "mlcube-meta.yaml" log_file = "medperf.log" log_package_file = "medperf_logs.tar.gz" 
@@ -151,11 +153,14 @@ demo_dset_paths_file = "paths.yaml" mlcube_cache_file = ".cache_metadata.yaml" training_exps_filename = "training-info.yaml" -training_exp_cols_filename = "cols.yaml" -agg_cert_folder = "agg_cert" -data_cert_folder = "data_cert" +participants_list_filename = "cols.yaml" +training_exp_plan_filename = "plan.yaml" +training_report_file = "report.yaml" +training_out_logs = "logs" +training_out_weights = "weights" ca_cert_folder = "ca_cert" -network_config_filename = "network.yaml" +ca_config_file = "ca_config.json" +agg_config_file = "aggregator_config.yaml" report_file = "report.yaml" metadata_folder = "metadata" statistics_filename = "statistics.yaml" diff --git a/cli/medperf/cryptography/__init__.py b/cli/medperf/cryptography/__init__.py deleted file mode 100644 index b3f394d12..000000000 --- a/cli/medperf/cryptography/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -"""openfl.cryptography package.""" diff --git a/cli/medperf/cryptography/ca.py b/cli/medperf/cryptography/ca.py deleted file mode 100644 index d651919a4..000000000 --- a/cli/medperf/cryptography/ca.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography CA utilities.""" - -import datetime -import uuid -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.x509.base import Certificate -from cryptography.x509.base import CertificateSigningRequest -from cryptography.x509.extensions import ExtensionNotFound -from cryptography.x509.name import Name -from cryptography.x509.oid import ExtensionOID -from cryptography.x509.oid import NameOID - - -def generate_root_cert( - 
common_name: str = "Simple Root CA", days_to_expiration: int = 365 -) -> Tuple[RSAPrivateKey, Certificate]: - """Generate_root_certificate.""" - now = datetime.datetime.utcnow() - expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) - - # Generate private key - root_private_key = rsa.generate_private_key( - public_exponent=65537, key_size=3072, backend=default_backend() - ) - - # Generate public key - root_public_key = root_private_key.public_key() - builder = x509.CertificateBuilder() - subject = x509.Name( - [ - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Root CA"), - ] - ) - issuer = subject - builder = builder.subject_name(subject) - builder = builder.issuer_name(issuer) - - builder = builder.not_valid_before(now) - builder = builder.not_valid_after(now + expiration_delta) - builder = builder.serial_number(int(uuid.uuid4())) - builder = builder.public_key(root_public_key) - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - - # Sign the CSR - certificate = builder.sign( - private_key=root_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - - return root_private_key, certificate - - -def generate_signing_csr() -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Generate signing CSR.""" - # Generate private key - signing_private_key = rsa.generate_private_key( - public_exponent=65537, key_size=3072, backend=default_backend() - ) - - builder = x509.CertificateSigningRequestBuilder() - subject = x509.Name( - [ - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "org"), - x509.NameAttribute(NameOID.DOMAIN_COMPONENT, "simple"), - x509.NameAttribute(NameOID.COMMON_NAME, "Simple Signing CA"), - 
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Simple Inc"), - x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Simple Signing CA"), - ] - ) - builder = builder.subject_name(subject) - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - - # Sign the CSR - csr = builder.sign( - private_key=signing_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - - return signing_private_key, csr - - -def sign_certificate( - csr: CertificateSigningRequest, - issuer_private_key: RSAPrivateKey, - issuer_name: Name, - days_to_expiration: int = 365, - ca: bool = False, -) -> Certificate: - """ - Sign the incoming CSR request. - - Args: - csr : Certificate Signing Request object - issuer_private_key : Root CA private key if the request is for the signing - CA; Signing CA private key otherwise - issuer_name : x509 Name - days_to_expiration : int (365 days by default) - ca : Is this a certificate authority - """ - now = datetime.datetime.utcnow() - expiration_delta = days_to_expiration * datetime.timedelta(1, 0, 0) - - builder = x509.CertificateBuilder() - builder = builder.subject_name(csr.subject) - builder = builder.issuer_name(issuer_name) - builder = builder.not_valid_before(now) - builder = builder.not_valid_after(now + expiration_delta) - builder = builder.serial_number(int(uuid.uuid4())) - builder = builder.public_key(csr.public_key()) - builder = builder.add_extension( - x509.BasicConstraints(ca=ca, path_length=None), - critical=True, - ) - try: - builder = builder.add_extension( - csr.extensions.get_extension_for_oid( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME - ).value, - critical=False, - ) - except ExtensionNotFound: - pass # Might not have alternative name - - signed_cert = builder.sign( - private_key=issuer_private_key, - algorithm=hashes.SHA384(), - backend=default_backend(), - ) - return signed_cert diff --git a/cli/medperf/cryptography/io.py b/cli/medperf/cryptography/io.py deleted file 
mode 100644 index 52bfc5e95..000000000 --- a/cli/medperf/cryptography/io.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography IO utilities.""" - -import os -from hashlib import sha384 -from pathlib import Path -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.hazmat.primitives.serialization import load_pem_private_key -from cryptography.x509.base import Certificate -from cryptography.x509.base import CertificateSigningRequest - - -def read_key(path: Path) -> RSAPrivateKey: - """ - Read private key. - - Args: - path : Path (pathlib) - - Returns: - private_key - """ - with open(path, 'rb') as f: - pem_data = f.read() - - signing_key = load_pem_private_key(pem_data, password=None) - # TODO: replace assert with exception / sys.exit - assert (isinstance(signing_key, rsa.RSAPrivateKey)) - return signing_key - - -def write_key(key: RSAPrivateKey, path: Path) -> None: - """ - Write private key. - - Args: - key : RSA private key object - path : Path (pathlib) - - """ - def key_opener(path, flags): - return os.open(path, flags, mode=0o600) - - with open(path, 'wb', opener=key_opener) as f: - f.write(key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() - )) - - -def read_crt(path: Path) -> Certificate: - """ - Read signed TLS certificate. 
- - Args: - path : Path (pathlib) - - Returns: - Cryptography TLS Certificate object - """ - with open(path, 'rb') as f: - pem_data = f.read() - - certificate = x509.load_pem_x509_certificate(pem_data) - # TODO: replace assert with exception / sys.exit - assert (isinstance(certificate, x509.Certificate)) - return certificate - - -def write_crt(certificate: Certificate, path: Path) -> None: - """ - Write cryptography certificate / csr. - - Args: - certificate : cryptography csr / certificate object - path : Path (pathlib) - - Returns: - Cryptography TLS Certificate object - """ - with open(path, 'wb') as f: - f.write(certificate.public_bytes( - encoding=serialization.Encoding.PEM, - )) - - -def read_csr(path: Path) -> Tuple[CertificateSigningRequest, str]: - """ - Read certificate signing request. - - Args: - path : Path (pathlib) - - Returns: - Cryptography CSR object - """ - with open(path, 'rb') as f: - pem_data = f.read() - - csr = x509.load_pem_x509_csr(pem_data) - # TODO: replace assert with exception / sys.exit - assert (isinstance(csr, x509.CertificateSigningRequest)) - return csr, get_csr_hash(csr) - - -def get_csr_hash(certificate: CertificateSigningRequest) -> str: - """ - Get hash of cryptography certificate. 
- - Args: - certificate : Cryptography CSR object - - Returns: - Hash of cryptography certificate / csr - """ - hasher = sha384() - encoded_bytes = certificate.public_bytes( - encoding=serialization.Encoding.PEM, - ) - hasher.update(encoded_bytes) - return hasher.hexdigest() diff --git a/cli/medperf/cryptography/participant.py b/cli/medperf/cryptography/participant.py deleted file mode 100644 index d6e94712b..000000000 --- a/cli/medperf/cryptography/participant.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (C) 2020-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Cryptography participant utilities.""" -from typing import Tuple - -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from cryptography.x509.base import CertificateSigningRequest -from cryptography.x509.oid import NameOID - - -def generate_csr(common_name: str, - server: bool = False) -> Tuple[RSAPrivateKey, CertificateSigningRequest]: - """Issue certificate signing request for server and client.""" - # Generate private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=3072, - backend=default_backend() - ) - - builder = x509.CertificateSigningRequestBuilder() - subject = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - ]) - builder = builder.subject_name(subject) - builder = builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True, - ) - if server: - builder = builder.add_extension( - x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.SERVER_AUTH]), - critical=True - ) - - else: - builder = builder.add_extension( - x509.ExtendedKeyUsage([x509.ExtendedKeyUsageOID.CLIENT_AUTH]), - critical=True - ) - - builder = builder.add_extension( - x509.KeyUsage( - digital_signature=True, - 
key_encipherment=True, - data_encipherment=False, - key_agreement=False, - content_commitment=False, - key_cert_sign=False, - crl_sign=False, - encipher_only=False, - decipher_only=False - ), - critical=True - ) - - builder = builder.add_extension( - x509.SubjectAlternativeName([x509.DNSName(common_name)]), - critical=False - ) - - # Sign the CSR - csr = builder.sign( - private_key=private_key, algorithm=hashes.SHA384(), - backend=default_backend() - ) - - return private_key, csr diff --git a/cli/medperf/cryptography/utils.py b/cli/medperf/cryptography/utils.py deleted file mode 100644 index 03f9eb940..000000000 --- a/cli/medperf/cryptography/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from cryptography.hazmat.primitives import serialization -from cryptography import x509 - - -def cert_to_str(cert): - return cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") - - -def str_to_cert(cert_str): - return x509.load_pem_x509_certificate(cert_str.encode("utf-8")) - - -def str_to_csr(csr_str): - return x509.load_pem_x509_csr(csr_str.encode("utf-8")) diff --git a/cli/medperf/entities/aggregator.py b/cli/medperf/entities/aggregator.py index c011b0865..47187df8b 100644 --- a/cli/medperf/entities/aggregator.py +++ b/cli/medperf/entities/aggregator.py @@ -1,21 +1,15 @@ import os -import yaml -import logging -import hashlib -from typing import List, Optional, Union +from pydantic import validator -from medperf.entities.interface import Entity, Uploadable +from medperf.entities.interface import Entity from medperf.entities.schemas import MedperfSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) + import medperf.config as config from medperf.account_management import get_medperf_user_data +import yaml -class Aggregator(Entity, MedperfSchema, Uploadable): +class Aggregator(Entity, MedperfSchema): """ Class representing a compatibility test report entry @@ -30,72 +24,60 @@ class Aggregator(Entity, 
MedperfSchema, Uploadable): - results """ - server_config: Optional[dict] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.address = self.server_config["address"] - self.port = self.server_config["port"] - self.generated_uid = self.__generate_uid() - - path = config.aggregators_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - self.network_config_path = os.path.join(path, config.network_config_filename) + metadata: dict = {} + config: dict + aggregation_mlcube: int - def __generate_uid(self): - """A helper that generates a unique hash for a server config.""" + @staticmethod + def get_type(): + return "aggregator" - params = str(self.server_config) - return hashlib.sha1(params.encode()).hexdigest() + @staticmethod + def get_storage_path(): + return config.aggregators_folder - def todict(self): - return self.extended_dict() + @staticmethod + def get_comms_retriever(): + return config.comms.get_aggregator - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Aggregator"]: - """Gets and creates instances of all the locally prepared aggregators + @staticmethod + def get_metadata_filename(): + return config.agg_file - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + @staticmethod + def get_comms_uploader(): + return config.comms.upload_aggregator - Returns: - List[Aggregator]: a list of Aggregator instances. 
- """ - logging.info("Retrieving all aggregators") - aggs = [] - if not local_only: - aggs = cls.__remote_all(filters=filters) + @validator("config", pre=True, always=True) + def check_config(cls, v, *, values, **kwargs): + keys = set(v.keys()) + allowed_keys = { + "address", + "port", + } + if keys != allowed_keys: + raise ValueError("config must contain two keys only: address and port") + return v - remote_uids = set([agg.id for agg in aggs]) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - local_aggs = cls.__local_all() + self.generated_uid = self.name - aggs += [agg for agg in local_aggs if agg.id not in remote_uids] + self.address = self.config["address"] + self.port = self.config["port"] - return aggs + self.config_path = os.path.join(self.path, config.agg_config_file) @classmethod - def __remote_all(cls, filters: dict) -> List["Aggregator"]: - aggs = [] - try: - comms_fn = cls.__remote_prefilter(filters) - aggs_meta = comms_fn() - aggs = [cls(**meta) for meta in aggs_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all aggregators from the server" - logging.warning(msg) - - return aggs + def from_experiment(cls, training_exp_uid: int) -> "Aggregator": + meta = config.comms.get_experiment_aggregator(training_exp_uid) + agg = cls(**meta) + agg.write() + return agg @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -109,125 +91,18 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_aggregators return comms_fn - @classmethod - def __local_all(cls) -> List["Aggregator"]: - aggs = [] - aggregator_storage = config.aggregators_folder - try: - uids = next(os.walk(aggregator_storage))[1] - except StopIteration: - msg = "Couldn't iterate over the aggregator directory" - logging.warning(msg) - raise 
MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - agg = cls(**local_meta) - aggs.append(agg) - - return aggs - - @classmethod - def get(cls, agg_uid: Union[str, int], local_only: bool = False) -> "Aggregator": - """Retrieves and creates a Aggregator instance from the comms instance. - If the aggregator is present in the user's machine then it retrieves it from there. - - Args: - agg_uid (str): server UID of the aggregator - - Returns: - Aggregator: Specified Aggregator Instance - """ - if not str(agg_uid).isdigit() or local_only: - return cls.__local_get(agg_uid) - - try: - return cls.__remote_get(agg_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Aggregator {agg_uid} from comms failed") - logging.info(f"Looking for aggregator {agg_uid} locally") - return cls.__local_get(agg_uid) - - @classmethod - def __remote_get(cls, agg_uid: int) -> "Aggregator": - """Retrieves and creates a Aggregator instance from the comms instance. - If the aggregator is present in the user's machine then it retrieves it from there. - - Args: - agg_uid (str): server UID of the aggregator - - Returns: - Aggregator: Specified Aggregator Instance - """ - logging.debug(f"Retrieving aggregator {agg_uid} remotely") - meta = config.comms.get_aggregator(agg_uid) - aggregator = cls(**meta) - aggregator.write() - return aggregator - - @classmethod - def __local_get(cls, agg_uid: Union[str, int]) -> "Aggregator": - """Retrieves and creates a Aggregator instance from the comms instance. - If the aggregator is present in the user's machine then it retrieves it from there. 
- - Args: - agg_uid (str): server UID of the aggregator - - Returns: - Aggregator: Specified Aggregator Instance - """ - logging.debug(f"Retrieving aggregator {agg_uid} locally") - local_meta = cls.__get_local_dict(agg_uid) - aggregator = cls(**local_meta) - return aggregator - - def write(self): - logging.info(f"Updating registration information for aggregator: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - - # write network config - with open(self.network_config_path, "w") as f: - yaml.dump(self.server_config, f) - - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test aggregators.") - aggregator_dict = self.todict() - updated_aggregator_dict = config.comms.upload_aggregator(aggregator_dict) - return updated_aggregator_dict - - @classmethod - def __get_local_dict(cls, aggregator_uid): - aggregator_path = os.path.join( - config.aggregators_folder, str(aggregator_uid) - ) - regfile = os.path.join(aggregator_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested aggregator information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg + def prepare_config(self): + with open(self.config_path, "w") as f: + yaml.dump(self.config, f) def display_dict(self): return { "UID": self.identifier, "Name": self.name, "Generated Hash": self.generated_uid, - "Address": self.server_config["address"], - "Port": self.server_config["port"], + "Address": self.address, + "MLCube": int(self.aggregation_mlcube), + "Port": self.port, "Created At": self.created_at, "Registered": self.is_registered, } diff --git 
a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 1d33efa95..9edd76a7e 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,4 +1,5 @@ from typing import List, Optional +from medperf.commands.association.utils import get_associations_list from pydantic import HttpUrl, Field import medperf.config as config @@ -86,12 +87,10 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: Returns: List[int]: List of mlcube uids """ - associations = config.comms.get_benchmark_model_associations(benchmark_uid) - models_uids = [ - assoc["model_mlcube"] - for assoc in associations - if assoc["approval_status"] == "APPROVED" - ] + associations = get_associations_list( + "benchmark", "model_mlcube", "APPROVED", experiment_id=benchmark_uid + ) + models_uids = [assoc["model_mlcube"] for assoc in associations] return models_uids def display_dict(self): diff --git a/cli/medperf/entities/ca.py b/cli/medperf/entities/ca.py new file mode 100644 index 000000000..7c945ac04 --- /dev/null +++ b/cli/medperf/entities/ca.py @@ -0,0 +1,115 @@ +import json +import os +from medperf.entities.interface import Entity +from medperf.entities.schemas import MedperfSchema +from pydantic import validator +import medperf.config as config +from medperf.account_management import get_medperf_user_data + + +class CA(Entity, MedperfSchema): + """ + Class representing a compatibility test report entry + + A test report is comprised of the components of a test execution: + - data used, which can be: + - a demo aggregator url and its hash, or + - a raw data path and its labels path, or + - a prepared aggregator uid + - Data preparation cube if the data used was not already prepared + - model cube + - evaluator cube + - results + """ + + metadata: dict = {} + client_mlcube: int + server_mlcube: int + ca_mlcube: int + config: dict + + @staticmethod + def get_type(): + return "ca" + + @staticmethod + def get_storage_path(): + return 
config.cas_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_ca + + @staticmethod + def get_metadata_filename(): + return config.ca_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_ca + + @validator("config", pre=True, always=True) + def check_config(cls, v, *, values, **kwargs): + keys = set(v.keys()) + allowed_keys = { + "address", + "port", + "fingerprint", + "client_provisioner", + "server_provisioner", + } + if keys != allowed_keys: + raise ValueError("config must contain two keys only: address and port") + return v + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + + self.address = self.config["address"] + self.port = self.config["port"] + self.fingerprint = self.config["fingerprint"] + self.client_provisioner = self.config["client_provisioner"] + self.server_provisioner = self.config["server_provisioner"] + + self.config_path = os.path.join(self.path, config.ca_config_file) + self.pki_assets = os.path.join(self.path, config.ca_cert_folder) + + @classmethod + def from_experiment(cls, training_exp_uid: int) -> "CA": + meta = config.comms.get_experiment_ca(training_exp_uid) + ca = cls(**meta) + ca.write() + return ca + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_cas + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_cas + return comms_fn + + def prepare_config(self): + with open(self.config_path, "w") as f: + json.dump(self.config, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Generated Hash": self.generated_uid, + "Address": self.address, + 
"fingerprint": self.fingerprint, + "Port": self.port, + "Created At": self.created_at, + "Registered": self.is_registered, + } diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py new file mode 100644 index 000000000..484ae4767 --- /dev/null +++ b/cli/medperf/entities/event.py @@ -0,0 +1,100 @@ +from datetime import datetime +import os +from typing import Optional +from medperf.entities.interface import Entity +import medperf.config as config +from medperf.entities.schemas import MedperfSchema +from medperf.account_management import get_medperf_user_data +import yaml + + +class TrainingEvent(Entity, MedperfSchema): + """ + Class representing a compatibility test report entry + + A test report is comprised of the components of a test execution: + - data used, which can be: + - a demo aggregator url and its hash, or + - a raw data path and its labels path, or + - a prepared aggregator uid + - Data preparation cube if the data used was not already prepared + - model cube + - evaluator cube + - results + """ + + training_exp: int + participants: dict + finished: bool = False + finished_at: Optional[datetime] + report: dict = {} + + @staticmethod + def get_type(): + return "training event" + + @staticmethod + def get_storage_path(): + return config.training_events_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_training_event + + @staticmethod + def get_metadata_filename(): + return config.training_event_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_training_event + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generated_uid = self.name + self.participants_list_path = os.path.join( + self.path, config.participants_list_filename + ) + self.out_logs = os.path.join(self.path, config.training_out_logs) + self.out_weights = os.path.join(self.path, config.training_out_weights) + self.report_path = os.path.join(self.path, config.training_report_file) + 
+ @classmethod + def from_experiment(cls, training_exp_uid: int) -> "TrainingEvent": + meta = config.comms.get_experiment_event(training_exp_uid) + ca = cls(**meta) + ca.write() + return ca + + @classmethod + def _Entity__remote_prefilter(cls, filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities + + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + comms_fn = config.comms.get_training_events + if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]: + comms_fn = config.comms.get_user_training_events + return comms_fn + + def prepare_participants_list(self): + with open(self.participants_list_path, "w") as f: + yaml.dump(self.participants, f) + + def display_dict(self): + return { + "UID": self.identifier, + "Name": self.name, + "Experiment": self.training_exp, + "Generated Hash": self.generated_uid, + "Participants": self.participants, + "Created At": self.created_at, + "Registered": self.is_registered, + "Finished": self.finished, + "Report": self.report, + } diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index 50f050361..f26a587f2 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -1,21 +1,16 @@ import os -from medperf.exceptions import MedperfException +from medperf.commands.association.utils import get_associations_list import yaml -import logging -from typing import List, Optional, Union -from pydantic import HttpUrl, Field, validator +from typing import List, Optional +from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.utils import get_dataset_common_name -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError +from medperf.entities.interface import Entity from medperf.entities.schemas import 
MedperfSchema, ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class TrainingExp( - Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema -): +class TrainingExp(Entity, MedperfSchema, ApprovableSchema, DeployableSchema): """ Class representing a TrainingExp @@ -28,22 +23,34 @@ class TrainingExp( description: Optional[str] = Field(None, max_length=20) docs_url: Optional[HttpUrl] - demo_dataset_tarball_url: Optional[str] - demo_dataset_tarball_hash: Optional[str] - demo_dataset_generated_uid: Optional[str] + demo_dataset_tarball_url: str + demo_dataset_tarball_hash: str + demo_dataset_generated_uid: str data_preparation_mlcube: int fl_mlcube: int - public_key: Optional[str] - datasets: List[int] = None + plan: dict = {} metadata: dict = {} user_metadata: dict = {} - @validator("datasets", pre=True, always=True) - def set_default_datasets_value(cls, value, values, **kwargs): - if not value: - # Empty or None value assigned - return [] - return value + @staticmethod + def get_type(): + return "training experiment" + + @staticmethod + def get_storage_path(): + return config.training_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_training_exp + + @staticmethod + def get_metadata_filename(): + return config.training_exps_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_training_exp def __init__(self, *args, **kwargs): """Creates a new training_exp instance @@ -54,64 +61,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.generated_uid = self.name - path = config.training_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - self.cert_path = os.path.join(path, config.ca_cert_folder) - self.cols_path = os.path.join(path, config.training_exp_cols_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = 
{}) -> List["TrainingExp"]: - """Gets and creates instances of all retrievable training_exps - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[TrainingExp]: a list of TrainingExp instances. - """ - logging.info("Retrieving all training_exps") - training_exps = [] - - if not local_only: - training_exps = cls.__remote_all(filters=filters) - - remote_uids = set([training_exp.id for training_exp in training_exps]) - - local_training_exps = cls.__local_all() - - training_exps += [ - training_exp - for training_exp in local_training_exps - if training_exp.id not in remote_uids - ] - - return training_exps - - @classmethod - def __remote_all(cls, filters: dict) -> List["TrainingExp"]: - training_exps = [] - try: - comms_fn = cls.__remote_prefilter(filters) - training_exps_meta = comms_fn() - for training_exp_meta in training_exps_meta: - # Loading all related models for all training_exps could be expensive. - # Most probably not necessary when getting all training_exps. 
- # If associated models for a training_exp are needed then use TrainingExp.get() - training_exp_meta["datasets"] = [] - training_exps = [cls(**meta) for meta in training_exps_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all training_exps from the server" - logging.warning(msg) - - return training_exps + self.plan_path = os.path.join(self.path, config.training_exp_plan_filename) @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + def _Entity__remote_prefilter(cls, filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -126,112 +79,24 @@ def __remote_prefilter(cls, filters: dict) -> callable: return comms_fn @classmethod - def __local_all(cls) -> List["TrainingExp"]: - training_exps = [] - training_exps_storage = config.training_folder - try: - uids = next(os.walk(training_exps_storage))[1] - except StopIteration: - msg = "Couldn't iterate over training_exps directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - training_exp = cls(**meta) - training_exps.append(training_exp) - - return training_exps - - @classmethod - def get( - cls, training_exp_uid: Union[str, int], local_only: bool = False - ) -> "TrainingExp": - """Retrieves and creates a TrainingExp instance from the server. - If training_exp already exists in the platform then retrieve that - version. - - Args: - training_exp_uid (str): UID of the training_exp. - comms (Comms): Instance of a communication interface. - - Returns: - TrainingExp: a TrainingExp instance with the retrieved data. 
- """ - - if not str(training_exp_uid).isdigit() or local_only: - return cls.__local_get(training_exp_uid) - - try: - return cls.__remote_get(training_exp_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting TrainingExp {training_exp_uid} from comms failed") - logging.info(f"Looking for training_exp {training_exp_uid} locally") - return cls.__local_get(training_exp_uid) - - @classmethod - def __remote_get(cls, training_exp_uid: int) -> "TrainingExp": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving training_exp {training_exp_uid} remotely") - training_exp_dict = config.comms.get_training_exp(training_exp_uid) - datasets = cls.get_datasets_uids(training_exp_uid) - training_exp_dict["datasets"] = datasets - training_exp = cls(**training_exp_dict) - training_exp.write() - return training_exp - - @classmethod - def __local_get(cls, training_exp_uid: Union[str, int]) -> "TrainingExp": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving training_exp {training_exp_uid} locally") - training_exp_dict = cls.__get_local_dict(training_exp_uid) - training_exp = cls(**training_exp_dict) - return training_exp - - @classmethod - def __get_local_dict(cls, training_exp_uid) -> dict: - """Retrieves a local training_exp information + def get_datasets_uids(cls, training_exp_uid: int) -> List[int]: + """Retrieves the list of models associated to the training_exp Args: - training_exp_uid (str): uid of the local training_exp + training_exp_uid (int): UID of the training_exp. 
+ comms (Comms): Instance of the communications interface. Returns: - dict: information of the training_exp + List[int]: List of mlcube uids """ - logging.info(f"Retrieving training_exp {training_exp_uid} from local storage") - training_exp_storage = os.path.join( - config.training_folder, str(training_exp_uid) - ) - training_exp_file = os.path.join( - training_exp_storage, config.training_exps_filename + associations = get_associations_list( + "training_exp", "dataset", "APPROVED", experiment_id=training_exp_uid ) - if not os.path.exists(training_exp_file): - raise InvalidArgumentError( - "No training_exp with the given uid could be found" - ) - with open(training_exp_file, "r") as f: - data = yaml.safe_load(f) - - return data + datasets_uids = [assoc["dataset"] for assoc in associations] + return datasets_uids @classmethod - def get_datasets_uids(cls, training_exp_uid: int) -> List[int]: + def get_datasets_with_users(cls, training_exp_uid: int) -> List[int]: """Retrieves the list of models associated to the training_exp Args: @@ -241,65 +106,14 @@ def get_datasets_uids(cls, training_exp_uid: int) -> List[int]: Returns: List[int]: List of mlcube uids """ - return config.comms.get_experiment_datasets(training_exp_uid) - - def todict(self) -> dict: - """Dictionary representation of the training_exp instance - - Returns: - dict: Dictionary containing training_exp information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the training_exp into disk - - Args: - filename (str, optional): name of the file. Defaults to config.training_exps_filename. 
- - Returns: - str: path to the created training_exp file - """ - data = self.todict() - training_exp_file = os.path.join(self.path, config.training_exps_filename) - if not os.path.exists(training_exp_file): - os.makedirs(self.path, exist_ok=True) - with open(training_exp_file, "w") as f: - yaml.dump(data, f) - - # write cert - os.makedirs(self.cert_path, exist_ok=True) - cert_file = os.path.join(self.cert_path, "cert.crt") - with open(cert_file, "w") as f: - f.write(self.public_key) - - # write cols - dataset_owners_emails = [""] * len( - self.datasets - ) # TODO (this will need some work) - # our medperf's user info endpoint is not public - # emails currently are not stored in medperf (auth0 only. in access tokens as well) - cols = [ - get_dataset_common_name(email, dataset_id, self.id) - for email, dataset_id in zip(dataset_owners_emails, self.datasets) - ] - with open(self.cols_path, "w") as f: - f.write("\n".join(cols)) - - return training_exp_file - - def upload(self): - """Uploads a training_exp to the server + uids_with_users = config.comms.get_experiment_datasets_with_users( + training_exp_uid + ) + return uids_with_users - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test training_exps.") - body = self.todict() - updated_body = config.comms.upload_training_exp(body) - updated_body["datasets"] = body["datasets"] - return updated_body + def prepare_plan(self): + with open(self.plan_path, "w") as f: + yaml.dump(self.plan, f) def display_dict(self): return { @@ -309,7 +123,7 @@ def display_dict(self): "Documentation": self.docs_url, "Created At": self.created_at, "FL MLCube": int(self.fl_mlcube), - "Associated Datasets": ",".join(map(str, self.datasets)), + "Plan": self.plan, "State": self.state, "Registered": self.is_registered, "Approval Status": self.approval_status, diff --git a/cli/medperf/tests/commands/association/test_approve.py 
b/cli/medperf/tests/commands/association/test_approve.py index 23c50c721..490351490 100644 --- a/cli/medperf/tests/commands/association/test_approve.py +++ b/cli/medperf/tests/commands/association/test_approve.py @@ -25,7 +25,7 @@ def test_run_fails_if_invalid_arguments(mocker, comms, ui, dset_uid, mlcube_uid) @pytest.mark.parametrize("status", [Status.APPROVED, Status.REJECTED]) def test_run_calls_comms_dset_approval_with_status(mocker, comms, ui, dset_uid, status): # Arrange - spy = mocker.patch.object(comms, "set_dataset_association_approval") + spy = mocker.patch.object(comms, "update_benchmark_dataset_association") # Act Approval.run(1, status, dataset_uid=dset_uid) @@ -40,7 +40,7 @@ def test_run_calls_comms_mlcube_approval_with_status( mocker, comms, ui, mlcube_uid, status ): # Arrange - spy = mocker.patch.object(comms, "set_mlcube_association_approval") + spy = mocker.patch.object(comms, "update_benchmark_model_association") # Act Approval.run(1, status, mlcube_uid=mlcube_uid) diff --git a/cli/medperf/tests/commands/association/test_priority.py b/cli/medperf/tests/commands/association/test_priority.py index 8d7a70392..81b67e602 100644 --- a/cli/medperf/tests/commands/association/test_priority.py +++ b/cli/medperf/tests/commands/association/test_priority.py @@ -36,7 +36,7 @@ def setup_comms(mocker, comms, associations): ) mocker.patch.object( comms, - "set_mlcube_association_priority", + "update_benchmark_model_association", side_effect=set_priority_behavior(associations), ) diff --git a/cli/medperf/tests/commands/benchmark/test_associate.py b/cli/medperf/tests/commands/benchmark/test_associate.py index 461a968f8..92c5b704f 100644 --- a/cli/medperf/tests/commands/benchmark/test_associate.py +++ b/cli/medperf/tests/commands/benchmark/test_associate.py @@ -11,8 +11,8 @@ def test_run_fails_if_model_and_dset_passed(mocker, model_uid, data_uid, comms, ui): # Arrange num_arguments = int(data_uid is None) + int(model_uid is None) - mocker.patch.object(comms, 
"associate_cube") - mocker.patch.object(comms, "associate_dset") + mocker.patch.object(comms, "associate_benchmark_model") + mocker.patch.object(comms, "associate_benchmark_dataset") mocker.patch(PATCH_ASSOC.format("AssociateCube.run")) mocker.patch(PATCH_ASSOC.format("AssociateDataset.run")) diff --git a/cli/medperf/tests/commands/dataset/test_associate.py b/cli/medperf/tests/commands/dataset/test_associate.py index 647eb6a13..8f1b2c1e7 100644 --- a/cli/medperf/tests/commands/dataset/test_associate.py +++ b/cli/medperf/tests/commands/dataset/test_associate.py @@ -71,7 +71,7 @@ def test_associates_if_approved( ): # Arrange result = TestResult() - assoc_func = "associate_dset" + assoc_func = "associate_benchmark_dataset" mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=True) exec_ret = [result] mocker.patch(PATCH_ASSOC.format("BenchmarkExecution.run"), return_value=exec_ret) @@ -93,7 +93,7 @@ def test_stops_if_not_approved(mocker, comms, ui, dataset, benchmark): exec_ret = [result] mocker.patch(PATCH_ASSOC.format("BenchmarkExecution.run"), return_value=exec_ret) spy = mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=False) - assoc_spy = mocker.patch.object(comms, "associate_dset") + assoc_spy = mocker.patch.object(comms, "associate_benchmark_dataset") # Act AssociateDataset.run(1, 1) @@ -110,7 +110,7 @@ def test_associate_calls_allows_cache_by_default(mocker, comms, ui, dataset, ben result = TestResult() data_uid = 1562 benchmark_uid = 3557 - assoc_func = "associate_dset" + assoc_func = "associate_benchmark_dataset" mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=True) exec_ret = [result] spy = mocker.patch( diff --git a/cli/medperf/tests/commands/mlcube/test_associate.py b/cli/medperf/tests/commands/mlcube/test_associate.py index cf72ab574..ab8743f12 100644 --- a/cli/medperf/tests/commands/mlcube/test_associate.py +++ b/cli/medperf/tests/commands/mlcube/test_associate.py @@ -30,7 +30,7 @@ def 
test_run_associates_cube_with_comms( mocker, cube, benchmark, cube_uid, benchmark_uid, comms, ui ): # Arrange - spy = mocker.patch.object(comms, "associate_cube") + spy = mocker.patch.object(comms, "associate_benchmark_model") comp_ret = ("", {}) mocker.patch.object(ui, "prompt", return_value="y") mocker.patch( @@ -70,7 +70,7 @@ def test_stops_if_not_approved(mocker, comms, ui, cube, benchmark): PATCH_ASSOC.format("CompatibilityTestExecution.run"), return_value=comp_ret ) spy = mocker.patch(PATCH_ASSOC.format("approval_prompt"), return_value=False) - assoc_spy = mocker.patch.object(comms, "associate_cube") + assoc_spy = mocker.patch.object(comms, "associate_benchmark_model") # Act AssociateCube.run(1, 1) diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index fb3596c98..e17f99f48 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -52,7 +52,7 @@ def server(mocker, ui): {"json": {}}, ), ( - "associate_dset", + "associate_benchmark_dataset", "post", 201, [1, 1], @@ -115,7 +115,7 @@ def test_methods_run_authorized_method(mocker, server, method_params): ("get_cube_metadata", [1], {}, CommunicationRetrievalError), ("upload_dataset", [{}], {"id": "invalid id"}, CommunicationRequestError), ("upload_result", [{}], {"id": "invalid id"}, CommunicationRequestError), - ("associate_dset", [1, 1], {}, CommunicationRequestError), + ("associate_benchmark_dataset", [1, 1], {}, CommunicationRequestError), ], ) def test_methods_exit_if_status_not_200(mocker, server, status, method_params): @@ -462,7 +462,9 @@ def test_upload_results_returns_result_body(mocker, server, body): @pytest.mark.parametrize("cube_uid", [2156, 915]) @pytest.mark.parametrize("benchmark_uid", [1206, 3741]) -def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchmark_uid): +def test_associate_benchmark_model_posts_association_data( + mocker, server, cube_uid, benchmark_uid +): # Arrange data = { 
"approval_status": Status.PENDING.value, @@ -474,7 +476,7 @@ def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchma spy = mocker.patch(patch_server.format("REST._REST__auth_post"), return_value=res) # Act - server.associate_cube(cube_uid, benchmark_uid) + server.associate_benchmark_model(cube_uid, benchmark_uid) # Assert spy.assert_called_once_with(ANY, json=data) @@ -483,7 +485,7 @@ def test_associate_cube_posts_association_data(mocker, server, cube_uid, benchma @pytest.mark.parametrize("dataset_uid", [4417, 1057]) @pytest.mark.parametrize("benchmark_uid", [1011, 635]) @pytest.mark.parametrize("status", [Status.APPROVED.value, Status.REJECTED.value]) -def test_set_dataset_association_approval_sets_approval( +def test_update_benchmark_dataset_association_sets_approval( mocker, server, dataset_uid, benchmark_uid, status ): # Arrange @@ -494,7 +496,7 @@ def test_set_dataset_association_approval_sets_approval( exp_url = f"{full_url}/datasets/{dataset_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_dataset_association_approval(benchmark_uid, dataset_uid, status) + server.update_benchmark_dataset_association(benchmark_uid, dataset_uid, status) # Assert spy.assert_called_once_with(exp_url, status) @@ -503,7 +505,7 @@ def test_set_dataset_association_approval_sets_approval( @pytest.mark.parametrize("mlcube_uid", [4596, 3530]) @pytest.mark.parametrize("benchmark_uid", [3966, 4188]) @pytest.mark.parametrize("status", [Status.APPROVED.value, Status.REJECTED.value]) -def test_set_mlcube_association_approval_sets_approval( +def test_update_benchmark_model_association_sets_approval( mocker, server, mlcube_uid, benchmark_uid, status ): # Arrange @@ -514,7 +516,7 @@ def test_set_mlcube_association_approval_sets_approval( exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_mlcube_association_approval(benchmark_uid, mlcube_uid, status) + server.update_benchmark_model_association(benchmark_uid, mlcube_uid, 
status) # Assert spy.assert_called_once_with(exp_url, status) @@ -576,7 +578,7 @@ def test_upload_benchmark_returns_benchmark_body(mocker, server, body): @pytest.mark.parametrize("mlcube_uid", [4596, 3530]) @pytest.mark.parametrize("benchmark_uid", [3966, 4188]) @pytest.mark.parametrize("priority", [2, -10]) -def test_set_mlcube_association_priority_sets_priority( +def test_update_benchmark_model_association_sets_priority( mocker, server, mlcube_uid, benchmark_uid, priority ): # Arrange @@ -585,7 +587,7 @@ def test_set_mlcube_association_priority_sets_priority( exp_url = f"{full_url}/mlcubes/{mlcube_uid}/benchmarks/{benchmark_uid}/" # Act - server.set_mlcube_association_priority(benchmark_uid, mlcube_uid, priority) + server.update_benchmark_model_association(benchmark_uid, mlcube_uid, priority) # Assert spy.assert_called_once_with(exp_url, json={"priority": priority}) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 97eff0b57..7d8042fde 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import re import os import signal @@ -15,15 +16,11 @@ import shutil from pexpect import spawn from datetime import datetime -from pydantic.datetime_parse import parse_datetime from typing import List from colorama import Fore, Style from pexpect.exceptions import TIMEOUT from git import Repo, GitCommandError import medperf.config as config -from medperf.cryptography.participant import generate_csr -from medperf.cryptography.io import get_csr_hash, write_key -from medperf.cryptography.utils import cert_to_str from medperf.exceptions import ExecutionError, MedperfException @@ -407,30 +404,6 @@ def get_cube_image_name(cube_path: str) -> str: raise MedperfException(msg) -def filter_latest_associations(associations, entity_key): - """Given a list of entity-benchmark associations, this function - retrieves a list containing the latest association of each - entity instance. 
- - Args: - associations (list[dict]): the list of associations - entity_key (str): either "dataset" or "model_mlcube" - - Returns: - list[dict]: the list containing the latest association of each - entity instance. - """ - - associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"])) - latest_associations = {} - for assoc in associations: - entity_id = assoc[entity_key] - latest_associations[entity_id] = assoc - - latest_associations = list(latest_associations.values()) - return latest_associations - - def check_for_updates() -> None: """Check if the current branch is up-to-date with its remote counterpart using GitPython.""" repo = Repo(config.BASE_DIR) @@ -511,46 +484,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def get_dataset_common_name(email, dataset_id, exp_id): - return f"{email}_d{dataset_id}_e{exp_id}".lower() - - -def generate_data_csr(email, data_uid, training_exp_id): - common_name = get_dataset_common_name(email, data_uid, training_exp_id) - private_key, csr = generate_csr(common_name, server=False) - - # store private key - target_folder = os.path.join( - config.training_folder, - str(training_exp_id), - config.data_cert_folder, - str(data_uid), - ) - os.makedirs(target_folder, exist_ok=True) - target_path = os.path.join(target_folder, "key.key") - write_key(private_key, target_path) - - csr_hash = get_csr_hash(csr) - csr_str = cert_to_str(csr) - return csr_str, csr_hash - - -def generate_agg_csr(training_exp_id, agg_address, agg_id): - common_name = f"{agg_address}".lower() - private_key, csr = generate_csr(common_name, server=True) - - # store private key - target_folder = os.path.join( - config.training_folder, - str(training_exp_id), - config.agg_cert_folder, - str(agg_id), - ) - os.makedirs(target_folder, exist_ok=True) - target_path = os.path.join(target_folder, "key.key") - write_key(private_key, target_path) +def get_pki_assets_path(common_name: str, ca_name: str): + # Base64 encoding is used just to avoid 
special characters used in emails + # and server domains/ipaddresses. + cn_encoded = base64.b64encode(common_name.encode("utf-8")).decode("utf-8") + return os.path.join(config.pki_assets, cn_encoded, ca_name) - csr_hash = get_csr_hash(csr) - csr_str = cert_to_str(csr) - return csr_str, csr_hash +def get_participant_label(email, data_id): + return f"{email}_d{data_id}" diff --git a/examples/fl/cert/mlcube/mlcube.yaml b/examples/fl/cert/mlcube/mlcube.yaml new file mode 100644 index 000000000..700782b00 --- /dev/null +++ b/examples/fl/cert/mlcube/mlcube.yaml @@ -0,0 +1,38 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-step-cli:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + trust: + entrypoint: /bin/bash /mlcube_project/trust.sh + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_client_cert: + entrypoint: /bin/bash /mlcube_project/get_cert.sh + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_server_cert: + entrypoint: /bin/bash /mlcube_project/get_cert.sh + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ diff --git a/examples/fl/cert/mlcube/workspace/ca_config.json b/examples/fl/cert/mlcube/workspace/ca_config.json new file mode 100644 index 000000000..fbf4b8696 --- /dev/null +++ b/examples/fl/cert/mlcube/workspace/ca_config.json @@ -0,0 +1,7 @@ +{ + "address": "https://example.com", + "port": 443, + "fingerprint": "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + "client_provisioner": "auth0", + "server_provisioner": "acme" +} \ No newline at end of file diff --git a/examples/fl/cert/project/Dockerfile 
b/examples/fl/cert/project/Dockerfile new file mode 100644 index 000000000..9a36c1bd2 --- /dev/null +++ b/examples/fl/cert/project/Dockerfile @@ -0,0 +1,7 @@ +FROM smallstep/step-cli:0.26.1 + +RUN apt-get update && apt-get install -y jq + +COPY . /mlcube_project + +ENTRYPOINT ["/bin/bash"] diff --git a/examples/fl/cert/project/get_cert.sh b/examples/fl/cert/project/get_cert.sh new file mode 100644 index 000000000..83c8bac40 --- /dev/null +++ b/examples/fl/cert/project/get_cert.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +# Read arguments +while [ "${1:-}" != "" ]; do + case "$1" in + "--ca_config"*) + ca_config="${1#*=}" + ;; + "--pki_assets"*) + pki_assets="${1#*=}" + ;; + *) + task=$1 + ;; + esac + shift +done + +# validate arguments +if [ -z "$ca_config" ]; then + echo "--ca_config is required" + exit 1 +fi + +if [ -z "$pki_assets" ]; then + echo "--pki_assets is required" + exit 1 +fi + +if [ -z "$MEDPERF_INPUT_CN" ]; then + echo "MEDPERF_INPUT_CN environment variable must be set" + exit 1 +fi + +CA_ADDRESS=$(jq -r '.address' $ca_config) +CA_PORT=$(jq -r '.port' $ca_config) +CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) +CA_CLIENT_PROVISIONER=$(jq -r '.client_provisioner' $ca_config) +CA_SERVER_PROVISIONER=$(jq -r '.server_provisioner' $ca_config) + +if [ "$task" = "get_server_cert" ]; then + PROVISIONER_ARGS="--provisioner $CA_SERVER_PROVISIONER" +elif [ "$task" = "get_client_cert" ]; then + PROVISIONER_ARGS="--provisioner $CA_CLIENT_PROVISIONER --console" +else + echo "Invalid task: Task should be get_server_cert or get_client_cert" + exit 1 +fi + +cert_path=$pki_assets/crt.crt +key_path=$pki_assets/key.key + +# trust the CA.
+step ca bootstrap --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT + +# generate private key and ask for a certificate +# $STEPPATH/certs/root_ca.crt is the path where step-ca stores the trusted ca cert by default +step ca certificate --ca-url $CA_ADDRESS:$CA_PORT \ + --root $STEPPATH/certs/root_ca.crt \ + $PROVISIONER_ARGS \ + $MEDPERF_INPUT_CN $cert_path $key_path diff --git a/examples/fl/cert/project/trust.sh b/examples/fl/cert/project/trust.sh new file mode 100644 index 000000000..7c9ecd4b8 --- /dev/null +++ b/examples/fl/cert/project/trust.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Read arguments +while [ "${1:-}" != "" ]; do + case "$1" in + "--ca_config"*) + ca_config="${1#*=}" + ;; + "--pki_assets"*) + pki_assets="${1#*=}" + ;; + *) + task=$1 + ;; + esac + shift +done + +# validate arguments +if [ -z "$ca_config" ]; then + echo "--ca_config is required" + exit 1 +fi + +if [ -z "$pki_assets" ]; then + echo "--pki_assets is required" + exit 1 +fi + +if [ "$task" != "trust" ]; then + echo "Invalid task: Task should be 'trust'" + exit 1 +fi + +CA_ADDRESS=$(jq -r '.address' $ca_config) +CA_PORT=$(jq -r '.port' $ca_config) +CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) + +# trust the CA. 
+rm -rf $pki_assets/* +step ca root $pki_assets/root_ca.crt --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT diff --git a/server/aggregator/models.py b/server/aggregator/models.py index 683e9441f..8947c626d 100644 --- a/server/aggregator/models.py +++ b/server/aggregator/models.py @@ -7,19 +7,19 @@ class Aggregator(models.Model): owner = models.ForeignKey(User, on_delete=models.PROTECT) name = models.CharField(max_length=20, unique=True) - address = models.CharField(max_length=300) - port = models.IntegerField() + config = models.JSONField() aggregation_mlcube = models.ForeignKey( "mlcube.MlCube", on_delete=models.PROTECT, related_name="aggregators", ) + is_valid = models.BooleanField(default=True) metadata = models.JSONField(default=dict, blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) def __str__(self): - return self.address + return str(self.config) class Meta: ordering = ["created_at"] diff --git a/server/ca/models.py b/server/ca/models.py index 322c7b098..b7def44f6 100644 --- a/server/ca/models.py +++ b/server/ca/models.py @@ -7,15 +7,17 @@ class CA(models.Model): owner = models.ForeignKey(User, on_delete=models.PROTECT) name = models.CharField(max_length=20, unique=True) - address = models.CharField(max_length=300) - port = models.IntegerField() - fingerprint = models.TextField() + config = models.JSONField() + client_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) + server_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) + ca_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) + is_valid = models.BooleanField(default=True) metadata = models.JSONField(default=dict, blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) def __str__(self): - return self.address + return str(self.config) class Meta: ordering = ["created_at"] diff --git
a/server/dataset/serializers.py b/server/dataset/serializers.py index aaee5aaab..49f6173dd 100644 --- a/server/dataset/serializers.py +++ b/server/dataset/serializers.py @@ -1,5 +1,6 @@ from rest_framework import serializers from .models import Dataset +from user.serializers import UserSerializer class DatasetFullSerializer(serializers.ModelSerializer): @@ -60,3 +61,14 @@ def validate(self, data): "User cannot update non editable fields in Operation mode" ) return data + + +class DatasetWithOwnerInfoSerializer(serializers.ModelSerializer): + """This is needed for training to get datasets and their owners + with one API call.""" + + owner = UserSerializer() + + class Meta: + model = Dataset + fields = ["id", "owner"] diff --git a/server/dataset/urls.py b/server/dataset/urls.py index a4186fa64..7b020662e 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -11,8 +11,8 @@ path("benchmarks/", bviews.BenchmarkDatasetList.as_view()), path("/benchmarks//", bviews.DatasetApproval.as_view()), # path("/benchmarks/", bviews.DatasetBenchmarksList.as_view()), - # path("/training_experiments/", tviews.DatasetExperimentList.as_view()), + # path("/training/", tviews.DatasetExperimentList.as_view()), # NOTE: when activating those two endpoints later, check permissions and write tests - path("training_experiments/", tviews.ExperimentDatasetList.as_view()), - path("/training_experiments//", tviews.DatasetApproval.as_view()), + path("training/", tviews.ExperimentDatasetList.as_view()), + path("/training//", tviews.DatasetApproval.as_view()), ] diff --git a/server/medperf/urls.py b/server/medperf/urls.py index f739ae155..fb68b5d8f 100644 --- a/server/medperf/urls.py +++ b/server/medperf/urls.py @@ -38,5 +38,6 @@ path("me/", include("utils.urls", namespace=API_VERSION), name="me"), path("training/", include("training.urls", namespace=API_VERSION), name="training"), path("aggregators/", include("aggregator.urls", namespace=API_VERSION), name="aggregator"), + path("cas/", 
include("ca.urls", namespace=API_VERSION), name="ca") ])), ] diff --git a/server/training/urls.py b/server/training/urls.py index ac3ad96cb..054da57fd 100644 --- a/server/training/urls.py +++ b/server/training/urls.py @@ -1,6 +1,5 @@ -from django.urls import path +from django.urls import path, include from . import views -import trainingevent.views as pviews app_name = "training" @@ -10,6 +9,7 @@ path("/datasets/", views.TrainingDatasetList.as_view()), path("/aggregator/", views.TrainingAggregator.as_view()), path("/ca/", views.TrainingCA.as_view()), - path("/plan/", pviews.EventDetail.as_view()), - path("plans/", pviews.EventList.as_view()), + path("/event/", views.GetTrainingEvent.as_view()), + path("/participants_info/", views.ParticipantsInfo.as_view()), + path("events/", include("trainingevent.urls", namespace=app_name), name="event"), ] diff --git a/server/training/views.py b/server/training/views.py index dd46c10bc..a2c13edac 100644 --- a/server/training/views.py +++ b/server/training/views.py @@ -5,12 +5,17 @@ TrainingExperimentListofDatasetsSerializer, ) from ca.serializers import CASerializer +from trainingevent.serializers import EventDetailSerializer +from dataset.serializers import DatasetWithOwnerInfoSerializer from django.http import Http404 from rest_framework.generics import GenericAPIView from rest_framework.response import Response from rest_framework import status from drf_spectacular.utils import extend_schema +from django.db.models import OuterRef, Subquery +from django.contrib.auth import get_user_model +from dataset.models import Dataset from .models import TrainingExperiment from .serializers import ( WriteTrainingExperimentSerializer, @@ -23,6 +28,8 @@ IsAggregatorOwner, ) +User = get_user_model() + class TrainingExperimentList(GenericAPIView): serializer_class = WriteTrainingExperimentSerializer @@ -125,6 +132,33 @@ def get(self, request, pk, format=None): return Response(serializer.data) +class GetTrainingEvent(GenericAPIView): + 
permission_classes = [ + IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + ] + serializer_class = EventDetailSerializer + queryset = "" + + def get_object(self, pk): + try: + training_exp = TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + event = training_exp.event + if not event: + raise Http404 + return event + + def get(self, request, pk, format=None): + """ + Retrieve latest event of a training experiment instance. + """ + event = self.get_object(pk) + serializer = EventDetailSerializer(event) + return Response(serializer.data) + + class TrainingExperimentDetail(GenericAPIView): serializer_class = ReadTrainingExperimentSerializer queryset = "" @@ -172,3 +206,35 @@ def delete(self, request, pk, format=None): training_exp = self.get_object(pk) training_exp.delete() return Response(status=status.HTTP_204_NO_CONTENT) + + +class ParticipantsInfo(GenericAPIView): + permission_classes = [IsAdmin | IsExpOwner] + serializer_class = TrainingExperimentListofDatasetsSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingExperiment.objects.get(pk=pk) + except TrainingExperiment.DoesNotExist: + raise Http404 + + def get(self, request, pk, format=None): + """ + Retrieve datasets associated with a training experiment instance. 
+ """ + training_exp = self.get_object(pk) + latest_datasets_assocs_status = ( + training_exp.traindataset_association_set.all() + .filter(dataset__id=OuterRef("id")) + .order_by("-created_at")[:1] + .values("approval_status") + ) + datasets_with_users = ( + Dataset.objects.all() + .annotate(assoc_status=Subquery(latest_datasets_assocs_status)) + .filter(assoc_status="APPROVED") + ) + datasets_with_users = self.paginate_queryset(datasets_with_users) + serializer = DatasetWithOwnerInfoSerializer(datasets_with_users, many=True) + return self.get_paginated_response(serializer.data) diff --git a/server/trainingevent/models.py b/server/trainingevent/models.py index 61376d77d..6c7b6bc7a 100644 --- a/server/trainingevent/models.py +++ b/server/trainingevent/models.py @@ -1,9 +1,15 @@ from django.db import models from training.models import TrainingExperiment +from django.contrib.auth import get_user_model + +User = get_user_model() # Create your models here. class TrainingEvent(models.Model): + name = models.CharField(max_length=20, unique=True) + owner = models.ForeignKey(User, on_delete=models.PROTECT) + is_valid = models.BooleanField(default=True) finished = models.BooleanField(default=False) training_exp = models.ForeignKey( TrainingExperiment, on_delete=models.PROTECT, related_name="events" @@ -11,6 +17,7 @@ class TrainingEvent(models.Model): participants = models.JSONField() report = models.JSONField(blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) finished_at = models.DateTimeField(null=True, blank=True) class Meta: diff --git a/server/trainingevent/serializers.py b/server/trainingevent/serializers.py index ca2080e3d..46f29243c 100644 --- a/server/trainingevent/serializers.py +++ b/server/trainingevent/serializers.py @@ -8,7 +8,7 @@ class EventSerializer(serializers.ModelSerializer): class Meta: model = TrainingEvent fields = "__all__" - read_only_fields = ["finished", "finished_at", 
"report"] + read_only_fields = ["finished", "finished_at", "report", "owner"] def validate(self, data): training_exp = TrainingExperiment.objects.get(pk=data["training_exp"]) @@ -39,7 +39,13 @@ class EventDetailSerializer(serializers.ModelSerializer): class Meta: model = TrainingEvent fields = "__all__" - read_only_fields = ["finished_at", "training_exp", "participants", "finished"] + read_only_fields = [ + "finished_at", + "training_exp", + "participants", + "finished", + "owner", + ] def validate(self, data): if self.instance.finished: diff --git a/server/trainingevent/urls.py b/server/trainingevent/urls.py new file mode 100644 index 000000000..e23cd188c --- /dev/null +++ b/server/trainingevent/urls.py @@ -0,0 +1,9 @@ +from django.urls import path +from . import views + +app_name = "events" + +urlpatterns = [ + path("", views.EventList.as_view()), + path("/", views.EventDetail.as_view()), +] diff --git a/server/trainingevent/views.py b/server/trainingevent/views.py index f9f3d2309..f73a9d28c 100644 --- a/server/trainingevent/views.py +++ b/server/trainingevent/views.py @@ -1,4 +1,4 @@ -from training.models import TrainingExperiment +from .models import TrainingEvent from django.http import Http404 from rest_framework.generics import GenericAPIView from rest_framework.response import Response @@ -19,10 +19,19 @@ def post(self, request, format=None): """ serializer = EventSerializer(data=request.data) if serializer.is_valid(): - serializer.save() + serializer.save(owner=request.user) return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def get(self, request, format=None): + """ + get all events + """ + events = TrainingEvent.objects.all() + events = self.paginate_queryset(events) + serializer = EventSerializer(events, many=True) + return self.get_paginated_response(serializer.data) + class EventDetail(GenericAPIView): serializer_class = EventDetailSerializer @@ -33,30 +42,25 @@ def 
get_permissions(self): self.permission_classes = [IsAdmin | IsExpOwner | IsAggregatorOwner] return super(self.__class__, self).get_permissions() - def get_object(self, tid): + def get_object(self, pk): try: - training_exp = TrainingExperiment.objects.get(pk=tid) - except TrainingExperiment.DoesNotExist: - raise Http404 - - event = training_exp.event - if not event: + return TrainingEvent.objects.get(pk=pk) + except TrainingEvent.DoesNotExist: raise Http404 - return event - def get(self, request, tid, format=None): + def get(self, request, pk, format=None): """ - Retrieve latest event of a training experiment + Retrieve an event """ - event = self.get_object(tid) + event = self.get_object(pk) serializer = EventDetailSerializer(event) return Response(serializer.data) - def put(self, request, tid, format=None): + def put(self, request, pk, format=None): """ - Update latest event of a training experiment + Update an event """ - event = self.get_object(tid) + event = self.get_object(pk) serializer = EventDetailSerializer(event, data=request.data) if serializer.is_valid(): serializer.save() From 33ad9684676067b523a0ef9209ea21d13883e3d6 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 14:46:01 +0200 Subject: [PATCH 045/242] update FL example --- cli/medperf/commands/training/run.py | 2 +- examples/fl/fl/build.sh | 6 ++++++ examples/fl/fl/mlcube/mlcube.yaml | 1 + examples/fl/fl/project/aggregator.py | 7 ++++++- examples/fl/fl/project/hooks.py | 2 ++ examples/fl/fl/project/mlcube.py | 4 ++++ examples/fl/fl/project/utils.py | 2 +- examples/fl/fl/test.sh | 8 ++++---- 8 files changed, 25 insertions(+), 7 deletions(-) diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/training/run.py index a7d48d26c..69b2e97ae 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/training/run.py @@ -77,7 +77,7 @@ def prepare_pki_assets(self): def run_experiment(self): participant_label = get_participant_label(self.user_email, 
self.dataset.id) - env_dict = {"COLLABORATOR_CN": participant_label} + env_dict = {"MEDPERF_PARTICIPANT_LABEL": participant_label} params = { "data_path": self.dataset.data_path, "labels_path": self.dataset.labels_path, diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh index d56304274..67cda94a7 100644 --- a/examples/fl/fl/build.sh +++ b/examples/fl/fl/build.sh @@ -1 +1,7 @@ +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/fl/mlcube/mlcube.yaml b/examples/fl/fl/mlcube/mlcube.yaml index 639a81118..65692efbb 100644 --- a/examples/fl/fl/mlcube/mlcube.yaml +++ b/examples/fl/fl/mlcube/mlcube.yaml @@ -36,6 +36,7 @@ tasks: outputs: output_logs: logs/ output_weights: final_weights/ + report_path: { type: "file", default: "report/report.yaml" } generate_plan: parameters: inputs: diff --git a/examples/fl/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py index 36d8b02f1..4d190e5d1 100644 --- a/examples/fl/fl/project/aggregator.py +++ b/examples/fl/fl/project/aggregator.py @@ -6,7 +6,7 @@ prepare_cols_list, prepare_init_weights, create_workspace, - get_weights_path, + # get_weights_path, ) import os @@ -23,6 +23,7 @@ def start_aggregator( output_weights, plan_path, collaborators, + report_path, ): workspace_folder = os.path.join(output_logs, "workspace") @@ -54,3 +55,7 @@ def start_aggregator( # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) + + # for now create an arbitrary report + with open(report_path, "w") as f: + f.write("agg_accuracy: 1.0") diff --git a/examples/fl/fl/project/hooks.py b/examples/fl/fl/project/hooks.py index 5b124e8b0..dd3960ba4 100644 --- a/examples/fl/fl/project/hooks.py +++ b/examples/fl/fl/project/hooks.py @@ -81,6 +81,7 @@ def 
aggregator_pre_training_hook( output_weights, plan_path, collaborators, + report_path, ): pass @@ -93,5 +94,6 @@ def aggregator_post_training_hook( output_weights, plan_path, collaborators, + report_path, ): pass diff --git a/examples/fl/fl/project/mlcube.py b/examples/fl/fl/project/mlcube.py index f3cfa640d..9e4a7e728 100644 --- a/examples/fl/fl/project/mlcube.py +++ b/examples/fl/fl/project/mlcube.py @@ -74,6 +74,7 @@ def start_aggregator_( output_weights: str = typer.Option(..., "--output_weights"), plan_path: str = typer.Option(..., "--plan_path"), collaborators: str = typer.Option(..., "--collaborators"), + report_path: str = typer.Option(..., "--report_path"), ): _setup(output_logs) aggregator_pre_training_hook( @@ -84,6 +85,7 @@ def start_aggregator_( output_weights=output_weights, plan_path=plan_path, collaborators=collaborators, + report_path=report_path, ) start_aggregator( input_weights=input_weights, @@ -93,6 +95,7 @@ def start_aggregator_( output_weights=output_weights, plan_path=plan_path, collaborators=collaborators, + report_path=report_path, ) aggregator_post_training_hook( input_weights=input_weights, @@ -102,6 +105,7 @@ def start_aggregator_( output_weights=output_weights, plan_path=plan_path, collaborators=collaborators, + report_path=report_path, ) _teardown(output_logs) diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index 558d0f8dd..76db41459 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -24,7 +24,7 @@ def get_aggregator_fqdn(fl_workspace): def get_collaborator_cn(): # TODO: check if there is a way this can cause a collision/race condition # TODO: from inside the file - return os.environ["COLLABORATOR_CN"] + return os.environ["MEDPERF_PARTICIPANT_LABEL"] def get_weights_path(fl_workspace): diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh index c6a77ffed..3a154936a 100644 --- a/examples/fl/fl/test.sh +++ b/examples/fl/fl/test.sh @@ -8,14 +8,14 @@ cp 
./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace # Run nodes AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e COLLABORATOR_CN=col1@example.com" -COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e COLLABORATOR_CN=col2@example.com" -COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e COLLABORATOR_CN=col3@example.com" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" gnome-terminal -- bash -c "$AGG; bash" gnome-terminal -- bash -c "$COL1; bash" gnome-terminal -- bash -c "$COL2; bash" gnome-terminal -- bash -c "$COL3; bash" -# docker run --env COLLABORATOR_CN=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume 
/home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 # python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 From 45bedcc1f92ca27108dc6503922a58cb3986f8bd Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 15:15:42 +0200 Subject: [PATCH 046/242] update/create missing APIs --- cli/medperf/comms/rest.py | 23 +++++++-- cli/medperf/entities/training_exp.py | 2 +- server/utils/urls.py | 11 +++-- server/utils/views.py | 71 +++++++++++++++++++++++++++- 4 files changed, 97 insertions(+), 10 deletions(-) diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 8f61e006e..6783230c8 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -367,13 +367,13 @@ def get_aggregators(self) -> List[dict]: return self.__get_list(url, error_msg=error_msg) def get_cas(self) -> List[dict]: - """Retrieves all training events + """Retrieves all cas Returns: - List[dict]: List of training events + List[dict]: List of cas """ - url = f"{self.server_url}/training/events/" - error_msg = "Could not retrieve training events" + url = f"{self.server_url}/cas/" + error_msg = "Could not retrieve cas" return self.__get_list(url, error_msg=error_msg) def get_training_events(self) -> List[dict]: @@ -841,7 +841,7 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict: return self.__get_list(url, error_msg=error_msg) def get_training_datasets_associations(self, training_exp_id: int) -> dict: - """Retrieves all approved 
datasets for a given training_exp + """Retrieves all datasets for a given training_exp Args: benchmark_id (int): benchmark ID to retrieve results from @@ -865,3 +865,16 @@ def get_benchmark_models_associations(self, benchmark_uid: int) -> List[int]: url = f"{self.server_url}/benchmarks/{benchmark_uid}/models" error_msg = "Could not get benchmark models associations" return self.__get_list(url, error_msg=error_msg) + + def get_training_datasets_with_users(self, training_exp_id: int) -> dict: + """Retrieves all datasets for a given training_exp and their owner information + + Args: + training_exp_id (int): training exp ID + + Returns: + dict: dictionary with the contents of dataset IDs and owner info + """ + url = f"{self.server_url}/training/{training_exp_id}/participants_info/" + error_msg = "Could not get training experiment participants info" + return self.__get_list(url, error_msg=error_msg) diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index f26a587f2..874a2e655 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -106,7 +106,7 @@ def get_datasets_with_users(cls, training_exp_uid: int) -> List[int]: Returns: List[int]: List of mlcube uids """ - uids_with_users = config.comms.get_experiment_datasets_with_users( + uids_with_users = config.comms.get_training_datasets_with_users( training_exp_uid ) return uids_with_users diff --git a/server/utils/urls.py b/server/utils/urls.py index 662a5e504..47f3c35c3 100644 --- a/server/utils/urls.py +++ b/server/utils/urls.py @@ -9,13 +9,18 @@ path("datasets/", views.DatasetList.as_view()), path("mlcubes/", views.MlCubeList.as_view()), path("results/", views.ModelResultList.as_view()), - path("datasets/associations/", views.DatasetAssociationList.as_view()), - path("mlcubes/associations/", views.MlCubeAssociationList.as_view()), path("training/", views.TrainingExperimentList.as_view()), path("aggregators/", views.AggregatorList.as_view()), + 
path("training/events/", views.TrainingEventList.as_view()), + path("cas/", views.CAList.as_view()), + path("datasets/associations/", views.DatasetAssociationList.as_view()), + path("mlcubes/associations/", views.MlCubeAssociationList.as_view()), path( "datasets/training_associations/", views.DatasetTrainingAssociationList.as_view(), ), - path("aggregators/associations/", views.AggregatorAssociationList.as_view()), + path( + "aggregators/training_associations/", views.AggregatorAssociationList.as_view() + ), + path("cas/training_associations/", views.CAAssociationList.as_view()), ] diff --git a/server/utils/views.py b/server/utils/views.py index 24a00ce29..d80303adc 100644 --- a/server/utils/views.py +++ b/server/utils/views.py @@ -27,6 +27,12 @@ from traindataset_association.serializers import ExperimentDatasetListSerializer from aggregator_association.models import ExperimentAggregator from aggregator_association.serializers import ExperimentAggregatorListSerializer +from ca_association.models import ExperimentCA +from ca_association.serializers import ExperimentCAListSerializer +from trainingevent.serializers import EventDetailSerializer +from ca.serializers import CASerializer +from trainingevent.models import TrainingEvent +from ca.models import CA class User(GenericAPIView): @@ -82,6 +88,26 @@ def get(self, request, format=None): return self.get_paginated_response(serializer.data) +class TrainingEventList(GenericAPIView): + serializer_class = EventDetailSerializer + queryset = "" + + def get_object(self, pk): + try: + return TrainingEvent.objects.filter(owner__id=pk) + except TrainingEvent.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all events owned by the current user + """ + training_events = self.get_object(request.user.id) + training_events = self.paginate_queryset(training_events) + serializer = EventDetailSerializer(training_events, many=True) + return self.get_paginated_response(serializer.data) + + class 
AggregatorList(GenericAPIView): serializer_class = AggregatorSerializer queryset = "" @@ -102,6 +128,26 @@ def get(self, request, format=None): return self.get_paginated_response(serializer.data) +class CAList(GenericAPIView): + serializer_class = CASerializer + queryset = "" + + def get_object(self, pk): + try: + return CA.objects.filter(owner__id=pk) + except CA.DoesNotExist: + raise Http404 + + def get(self, request, format=None): + """ + Retrieve all CAs owned by the current user + """ + cas = self.get_object(request.user.id) + cas = self.paginate_queryset(cas) + serializer = CASerializer(cas, many=True) + return self.get_paginated_response(serializer.data) + + class MlCubeList(GenericAPIView): serializer_class = MlCubeSerializer queryset = "" @@ -205,6 +251,7 @@ def get(self, request, format=None): serializer = BenchmarkModelListSerializer(benchmarkmodels, many=True) return self.get_paginated_response(serializer.data) + class DatasetTrainingAssociationList(GenericAPIView): serializer_class = ExperimentDatasetListSerializer queryset = "" @@ -227,13 +274,13 @@ def get(self, request, format=None): serializer = ExperimentDatasetListSerializer(experiment_datasets, many=True) return self.get_paginated_response(serializer.data) + class AggregatorAssociationList(GenericAPIView): serializer_class = ExperimentAggregatorListSerializer queryset = "" def get_object(self, pk): try: - # TODO: this retrieves everything (not just latest ones) return ExperimentAggregator.objects.filter( Q(aggregator__owner__id=pk) | Q(training_exp__owner__id=pk) ) @@ -250,6 +297,28 @@ def get(self, request, format=None): return self.get_paginated_response(serializer.data) +class CAAssociationList(GenericAPIView): + serializer_class = ExperimentCAListSerializer + queryset = "" + + def get_object(self, pk): + try: + return ExperimentCA.objects.filter( + Q(ca__owner__id=pk) | Q(training_exp__owner__id=pk) + ) + except ExperimentCA.DoesNotExist: + raise Http404 + + def get(self, request, 
format=None): + """ + Retrieve all ca associations involving an asset of mine + """ + experiment_cas = self.get_object(request.user.id) + experiment_cas = self.paginate_queryset(experiment_cas) + serializer = ExperimentCAListSerializer(experiment_cas, many=True) + return self.get_paginated_response(serializer.data) + + class ServerAPIVersion(GenericAPIView): permission_classes = (AllowAny,) queryset = "" From 2343ba1eaa0c0f42230aaf1c6b8888c5d6470a06 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 15:47:14 +0200 Subject: [PATCH 047/242] update new migration files --- server/.env.local.local-auth | 10 +++++- server/aggregator/migrations/0001_initial.py | 6 ++-- .../migrations/0001_initial.py | 2 +- .../migrations/0002_initial.py | 2 +- server/ca/migrations/0001_initial.py | 32 ++++++++++++++++--- server/ca/migrations/0002_createmedperfca.py | 22 ++++++++++--- server/ca/models.py | 12 +++++-- .../ca_association/migrations/0001_initial.py | 2 +- .../ca_association/migrations/0002_initial.py | 4 +-- server/medperf/settings.py | 10 +++--- .../migrations/0001_initial.py | 2 +- .../migrations/0002_initial.py | 4 +-- server/training/migrations/0001_initial.py | 2 +- .../trainingevent/migrations/0001_initial.py | 14 +++++++- 14 files changed, 95 insertions(+), 29 deletions(-) diff --git a/server/.env.local.local-auth b/server/.env.local.local-auth index fa5ea9c33..45a131d3d 100644 --- a/server/.env.local.local-auth +++ b/server/.env.local.local-auth @@ -20,4 +20,12 @@ GS_BUCKET_NAME= AUTH_AUDIENCE=https://localhost-localdev/ AUTH_ISSUER=https://localhost:8000/ AUTH_JWK_URL= -AUTH_VERIFYING_KEY="-----BEGIN PUBLIC 
KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAtKO1SzU6N/sZTJmYNk0C\n/5XbK8eWfcKX2HxFl7fr0V++wrXXGsMs9A8hQEbVWtgYbWaOSkXN0ojmcUt1NFcb\nSPYLmOK/oUXVASEbuZAdIi+ByQ1EnIIAmYSKjRBDUQM8wc73Z9AvrjnhrvEHyrIN\nKyXeLnaCKj/r0s5sQA85SngnCWQbZsRQyHysfsQLwguG0SKFF9EfdNJiaoD8lLBo\nqvUQIYi8MXuVAB7O5EomJoZJe7KEeemsLhCnjTlKHcumjnAiRy5Y0rL6aFXgQkg0\nY4NWxMbsIWAplzh2qCs2jEd88mAUJnHkMzeOKhb1Q+tcmg6ZG6GmwT9fujsOjYrn\na/RTx83B1rRVRHHBFsEP4/ctVf2VdARz+RO+mIh5yZsPiqmRSKpHfbKgnkBpQlAj\nwVrzP9HYT11EXGFesLKRt6Oin0I5FkJ1Ji4w680XjeyZ4KInMY87OvQtltIyrZI9\nR9uY9EnpISGYch6kxbVw0GzdQdP/0mUnYlIeWwyvsXsWB/b3pZ9BiQuCMtlxoWlk\naRjWk9dWIZKFL2uhgeNeY5Wh3Qx9EFx8hnz9ohdaNBPB5BNO2qI61NedFrjYN9LF\nSfcGL7iATU1JQS4rDisnyjDikkTHL9B1u6sMrTsoaqi9Dl5b0gC8RnPVnJItasMN\n9HcW8Pfo2Ava4ler7oU47jUCAwEAAQ==\n-----END PUBLIC KEY-----" \ No newline at end of file +AUTH_VERIFYING_KEY="-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAtKO1SzU6N/sZTJmYNk0C\n/5XbK8eWfcKX2HxFl7fr0V++wrXXGsMs9A8hQEbVWtgYbWaOSkXN0ojmcUt1NFcb\nSPYLmOK/oUXVASEbuZAdIi+ByQ1EnIIAmYSKjRBDUQM8wc73Z9AvrjnhrvEHyrIN\nKyXeLnaCKj/r0s5sQA85SngnCWQbZsRQyHysfsQLwguG0SKFF9EfdNJiaoD8lLBo\nqvUQIYi8MXuVAB7O5EomJoZJe7KEeemsLhCnjTlKHcumjnAiRy5Y0rL6aFXgQkg0\nY4NWxMbsIWAplzh2qCs2jEd88mAUJnHkMzeOKhb1Q+tcmg6ZG6GmwT9fujsOjYrn\na/RTx83B1rRVRHHBFsEP4/ctVf2VdARz+RO+mIh5yZsPiqmRSKpHfbKgnkBpQlAj\nwVrzP9HYT11EXGFesLKRt6Oin0I5FkJ1Ji4w680XjeyZ4KInMY87OvQtltIyrZI9\nR9uY9EnpISGYch6kxbVw0GzdQdP/0mUnYlIeWwyvsXsWB/b3pZ9BiQuCMtlxoWlk\naRjWk9dWIZKFL2uhgeNeY5Wh3Qx9EFx8hnz9ohdaNBPB5BNO2qI61NedFrjYN9LF\nSfcGL7iATU1JQS4rDisnyjDikkTHL9B1u6sMrTsoaqi9Dl5b0gC8RnPVnJItasMN\n9HcW8Pfo2Ava4ler7oU47jUCAwEAAQ==\n-----END PUBLIC KEY-----" + +#CA configuration +CA_NAME="MedPerf CA" +CA_CONFIG={"address":"127.0.0.1","port":443,"fingerprint":"fingerprint","client_provisioner":"auth0","server_provisioner":"acme"} +CA_MLCUBE_NAME="MedPerf CA" +CA_MLCUBE_URL="url" +CA_MLCUBE_HASH="hash" +CA_MLCUBE_IMAGE_HASH="hash" \ No newline at end of file diff --git 
a/server/aggregator/migrations/0001_initial.py b/server/aggregator/migrations/0001_initial.py index 4b4ed8ec2..f6d026115 100644 --- a/server/aggregator/migrations/0001_initial.py +++ b/server/aggregator/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models @@ -28,8 +28,8 @@ class Migration(migrations.Migration): ), ), ("name", models.CharField(max_length=20, unique=True)), - ("address", models.CharField(max_length=300)), - ("port", models.IntegerField()), + ("config", models.JSONField()), + ("is_valid", models.BooleanField(default=True)), ("metadata", models.JSONField(blank=True, default=dict, null=True)), ("created_at", models.DateTimeField(auto_now_add=True)), ("modified_at", models.DateTimeField(auto_now=True)), diff --git a/server/aggregator_association/migrations/0001_initial.py b/server/aggregator_association/migrations/0001_initial.py index 5308cd5fb..56b2e466a 100644 --- a/server/aggregator_association/migrations/0001_initial.py +++ b/server/aggregator_association/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models diff --git a/server/aggregator_association/migrations/0002_initial.py b/server/aggregator_association/migrations/0002_initial.py index 9f0f3f66c..ef70e0e1a 100644 --- a/server/aggregator_association/migrations/0002_initial.py +++ b/server/aggregator_association/migrations/0002_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.db import migrations, models import django.db.models.deletion diff --git a/server/ca/migrations/0001_initial.py b/server/ca/migrations/0001_initial.py index c58b3d149..875545c00 100644 --- 
a/server/ca/migrations/0001_initial.py +++ b/server/ca/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models @@ -10,6 +10,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), migrations.swappable_dependency(settings.AUTH_USER_MODEL), ] @@ -27,12 +28,27 @@ class Migration(migrations.Migration): ), ), ("name", models.CharField(max_length=20, unique=True)), - ("address", models.CharField(max_length=300)), - ("port", models.IntegerField()), - ("fingerprint", models.TextField()), + ("config", models.JSONField()), + ("is_valid", models.BooleanField(default=True)), ("metadata", models.JSONField(blank=True, default=dict, null=True)), ("created_at", models.DateTimeField(auto_now_add=True)), ("modified_at", models.DateTimeField(auto_now=True)), + ( + "ca_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca", + to="mlcube.mlcube", + ), + ), + ( + "client_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca_client", + to="mlcube.mlcube", + ), + ), ( "owner", models.ForeignKey( @@ -40,6 +56,14 @@ class Migration(migrations.Migration): to=settings.AUTH_USER_MODEL, ), ), + ( + "server_mlcube", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="ca_server", + to="mlcube.mlcube", + ), + ), ], options={ "ordering": ["created_at"], diff --git a/server/ca/migrations/0002_createmedperfca.py b/server/ca/migrations/0002_createmedperfca.py index 3c4a4f57f..f93eb513b 100644 --- a/server/ca/migrations/0002_createmedperfca.py +++ b/server/ca/migrations/0002_createmedperfca.py @@ -4,6 +4,7 @@ from django.db.migrations.state import StateApps from django.conf import settings from ca.models import CA +from mlcube.models import MlCube User = 
get_user_model() @@ -13,11 +14,20 @@ def createmedperfca(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> Non Dynamically create the configured main CA as part of a migration """ admin_user = User.objects.get(username=settings.SUPERUSER_USERNAME) + ca_mlcube = MlCube.objects.create( + name=settings.CA_MLCUBE_NAME, + git_mlcube_url=settings.CA_MLCUBE_URL, + mlcube_hash=settings.CA_MLCUBE_HASH, + image_hash=settings.CA_MLCUBE_IMAGE_HASH, + owner=admin_user, + state="OPERATION", + ) CA.objects.create( name=settings.CA_NAME, - address=settings.CA_ADDRESS, - port=settings.CA_PORT, - fingerprint=settings.CA_FINGERPRINT, + config=settings.CA_CONFIG, + ca_mlcube=ca_mlcube, + client_mlcube=ca_mlcube, + server_mlcube=ca_mlcube, owner=admin_user, ) @@ -25,5 +35,9 @@ def createmedperfca(apps: StateApps, schema_editor: DatabaseSchemaEditor) -> Non class Migration(migrations.Migration): initial = True - dependencies = [("ca", "0001_initial"), ("user", "0001_createsuperuser")] + dependencies = [ + ("ca", "0001_initial"), + ("user", "0001_createsuperuser"), + ("mlcube", "0002_alter_mlcube_unique_together"), + ] operations = [migrations.RunPython(createmedperfca)] diff --git a/server/ca/models.py b/server/ca/models.py index b7def44f6..1165731fb 100644 --- a/server/ca/models.py +++ b/server/ca/models.py @@ -8,9 +8,15 @@ class CA(models.Model): owner = models.ForeignKey(User, on_delete=models.PROTECT) name = models.CharField(max_length=20, unique=True) config = models.JSONField() - client_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) - server_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) - ca_mlcube = models.ForeignKey("mlcube.MlCube", on_delete=models.PROTECT) + client_mlcube = models.ForeignKey( + "mlcube.MlCube", on_delete=models.PROTECT, related_name="ca_client" + ) + server_mlcube = models.ForeignKey( + "mlcube.MlCube", on_delete=models.PROTECT, related_name="ca_server" + ) + ca_mlcube = models.ForeignKey( + 
"mlcube.MlCube", on_delete=models.PROTECT, related_name="ca" + ) is_valid = models.BooleanField(default=True) metadata = models.JSONField(default=dict, blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) diff --git a/server/ca_association/migrations/0001_initial.py b/server/ca_association/migrations/0001_initial.py index 9de3ccbee..5f9d17362 100644 --- a/server/ca_association/migrations/0001_initial.py +++ b/server/ca_association/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models diff --git a/server/ca_association/migrations/0002_initial.py b/server/ca_association/migrations/0002_initial.py index d56842a1f..848d72c90 100644 --- a/server/ca_association/migrations/0002_initial.py +++ b/server/ca_association/migrations/0002_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.db import migrations, models import django.db.models.deletion @@ -9,8 +9,8 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ("ca_association", "0001_initial"), ("training", "0001_initial"), + ("ca_association", "0001_initial"), ] operations = [ diff --git a/server/medperf/settings.py b/server/medperf/settings.py index 53df5b5f4..4b4710735 100644 --- a/server/medperf/settings.py +++ b/server/medperf/settings.py @@ -60,10 +60,12 @@ SUPERUSER_PASSWORD = env("SUPERUSER_PASSWORD") -CA_NAME = "MedPerf CA" -CA_ADDRESS = env("CA_ADDRESS") -CA_FINGERPRINT = env("CA_FINGERPRINT") -CA_PORT = env("CA_PORT") +CA_NAME = env("CA_NAME") +CA_CONFIG = env.json("CA_CONFIG") +CA_MLCUBE_NAME = env("CA_MLCUBE_NAME") +CA_MLCUBE_URL = env("CA_MLCUBE_URL") +CA_MLCUBE_HASH = env("CA_MLCUBE_HASH") +CA_MLCUBE_IMAGE_HASH = env("CA_MLCUBE_IMAGE_HASH") ALLOWED_HOSTS = env.list("ALLOWED_HOSTS", default=[]) diff --git 
a/server/traindataset_association/migrations/0001_initial.py b/server/traindataset_association/migrations/0001_initial.py index ac68ba155..7938a2561 100644 --- a/server/traindataset_association/migrations/0001_initial.py +++ b/server/traindataset_association/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models diff --git a/server/traindataset_association/migrations/0002_initial.py b/server/traindataset_association/migrations/0002_initial.py index ec851bd0b..6ab74922e 100644 --- a/server/traindataset_association/migrations/0002_initial.py +++ b/server/traindataset_association/migrations/0002_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.db import migrations, models import django.db.models.deletion @@ -9,8 +9,8 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ("traindataset_association", "0001_initial"), ("training", "0001_initial"), + ("traindataset_association", "0001_initial"), ] operations = [ diff --git a/server/training/migrations/0001_initial.py b/server/training/migrations/0001_initial.py index c4fae0e23..4f5b65dd5 100644 --- a/server/training/migrations/0001_initial.py +++ b/server/training/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 from django.conf import settings from django.db import migrations, models diff --git a/server/trainingevent/migrations/0001_initial.py b/server/trainingevent/migrations/0001_initial.py index 894b79121..3aaa8d673 100644 --- a/server/trainingevent/migrations/0001_initial.py +++ b/server/trainingevent/migrations/0001_initial.py @@ -1,5 +1,6 @@ -# Generated by Django 4.2.11 on 2024-04-23 01:12 +# Generated by Django 4.2.11 on 2024-04-29 13:21 +from 
django.conf import settings from django.db import migrations, models import django.db.models.deletion @@ -10,6 +11,7 @@ class Migration(migrations.Migration): dependencies = [ ("training", "0001_initial"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), ] operations = [ @@ -25,11 +27,21 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), + ("name", models.CharField(max_length=20, unique=True)), + ("is_valid", models.BooleanField(default=True)), ("finished", models.BooleanField(default=False)), ("participants", models.JSONField()), ("report", models.JSONField(blank=True, null=True)), ("created_at", models.DateTimeField(auto_now_add=True)), + ("modified_at", models.DateTimeField(auto_now=True)), ("finished_at", models.DateTimeField(blank=True, null=True)), + ( + "owner", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to=settings.AUTH_USER_MODEL, + ), + ), ( "training_exp", models.ForeignKey( From 19c80d88deaad27b353d1cb9bc180757534027aa Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 17:37:01 +0200 Subject: [PATCH 048/242] add mock crt mlcube --- .gitignore | 6 ++- examples/fl/mock_cert/build.sh | 1 + examples/fl/mock_cert/clean.sh | 1 + examples/fl/mock_cert/mlcube/mlcube.yaml | 35 +++++++++++++ .../mock_cert/mlcube/workspace/ca_config.json | 7 +++ examples/fl/mock_cert/project/Dockerfile | 11 +++++ .../fl/mock_cert/project/ca/cert/root.crt | 28 +++++++++++ examples/fl/mock_cert/project/ca/root.key | 40 +++++++++++++++ examples/fl/mock_cert/project/csr.conf | 23 +++++++++ examples/fl/mock_cert/project/mlcube.py | 49 +++++++++++++++++++ .../fl/mock_cert/project/requirements.txt | 2 + examples/fl/mock_cert/project/sign.sh | 31 ++++++++++++ examples/fl/mock_cert/test.sh | 6 +++ 13 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 examples/fl/mock_cert/build.sh create mode 100644 examples/fl/mock_cert/clean.sh create mode 100644 examples/fl/mock_cert/mlcube/mlcube.yaml create mode 100644 
examples/fl/mock_cert/mlcube/workspace/ca_config.json create mode 100644 examples/fl/mock_cert/project/Dockerfile create mode 100644 examples/fl/mock_cert/project/ca/cert/root.crt create mode 100644 examples/fl/mock_cert/project/ca/root.key create mode 100644 examples/fl/mock_cert/project/csr.conf create mode 100644 examples/fl/mock_cert/project/mlcube.py create mode 100644 examples/fl/mock_cert/project/requirements.txt create mode 100644 examples/fl/mock_cert/project/sign.sh create mode 100644 examples/fl/mock_cert/test.sh diff --git a/.gitignore b/.gitignore index 3bf5c7327..26d4cdd2e 100644 --- a/.gitignore +++ b/.gitignore @@ -147,4 +147,8 @@ cython_debug/ # Dev Environment Specific .vscode .venv -server/keys \ No newline at end of file +server/keys + +# exclude fl example +!examples/fl/mock_cert/project/ca/root.key +!examples/fl/mock_cert/project/ca/cert/root.crt diff --git a/examples/fl/mock_cert/build.sh b/examples/fl/mock_cert/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/mock_cert/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/mock_cert/clean.sh b/examples/fl/mock_cert/clean.sh new file mode 100644 index 000000000..fa72ddd6d --- /dev/null +++ b/examples/fl/mock_cert/clean.sh @@ -0,0 +1 @@ +rm -rf mlcube/workspace/pki_assets diff --git a/examples/fl/mock_cert/mlcube/mlcube.yaml b/examples/fl/mock_cert/mlcube/mlcube.yaml new file mode 100644 index 000000000..8019d3579 --- /dev/null +++ b/examples/fl/mock_cert/mlcube/mlcube.yaml @@ -0,0 +1,35 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-test-ca:0.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile" + +tasks: + trust: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_client_cert: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ + get_server_cert: + parameters: + inputs: + ca_config: ca_config.json + outputs: + pki_assets: pki_assets/ diff --git a/examples/fl/mock_cert/mlcube/workspace/ca_config.json b/examples/fl/mock_cert/mlcube/workspace/ca_config.json new file mode 100644 index 000000000..bcf246a03 --- /dev/null +++ b/examples/fl/mock_cert/mlcube/workspace/ca_config.json @@ -0,0 +1,7 @@ +{ + "address": "https://127.0.0.1", + "port": 443, + "fingerprint": "fingerprint", + "client_provisioner": "auth0", + "server_provisioner": "acme" +} \ No newline at end of file diff --git a/examples/fl/mock_cert/project/Dockerfile b/examples/fl/mock_cert/project/Dockerfile new file mode 100644 index 000000000..91c477415 --- /dev/null +++ b/examples/fl/mock_cert/project/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.9.16-slim + +COPY ./requirements.txt /mlcube_project/requirements.txt + +RUN pip3 install --no-cache-dir -r /mlcube_project/requirements.txt + +ENV LANG C.UTF-8 + +COPY . 
/mlcube_project + +ENTRYPOINT ["python3", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/mock_cert/project/ca/cert/root.crt b/examples/fl/mock_cert/project/ca/cert/root.crt new file mode 100644 index 000000000..813cd7165 --- /dev/null +++ b/examples/fl/mock_cert/project/ca/cert/root.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyzCCAzOgAwIBAgIUd8btUDxu7RR87iJZhUjzturqti8wDQYJKoZIhvcNAQEM +BQAwdDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBnNpbXBs +ZTEXMBUGA1UEAwwOU2ltcGxlIFJvb3QgQ0ExEzARBgNVBAoMClNpbXBsZSBJbmMx +FzAVBgNVBAsMDlNpbXBsZSBSb290IENBMCAXDTI0MDQyOTEzNTYxMloYDzIxMjQw +NDA1MTM1NjEyWjB0MRMwEQYKCZImiZPyLGQBGRYDb3JnMRYwFAYKCZImiZPyLGQB +GRYGc2ltcGxlMRcwFQYDVQQDDA5TaW1wbGUgUm9vdCBDQTETMBEGA1UECgwKU2lt +cGxlIEluYzEXMBUGA1UECwwOU2ltcGxlIFJvb3QgQ0EwggGiMA0GCSqGSIb3DQEB +AQUAA4IBjwAwggGKAoIBgQDMqoMT6iE/qRFJ/X+N9pp/WUoYaDDliUYuDdgb0pqq +6uqA50tfYmCOWal1K1Gq/4Hgi0OKsyj0bMemtRNXXH8r8qtjLNNmGmyeZICDe9FT +37gNr9uYWtVuwWpTI9bksxGVg9E0qx0U6fo+Puiu5ImDF/iYy1931ghijbOj0qWQ +M2dQi85baF/6uEHZ18b+c7K/toXCNhzJWrpw88DUyPerhkoe/JTI2kSNKZwULuan +VKazUZ4JIPF6NWhQGb/hcI+tTBkJXlETjrpN8A3hqVp6vpZTiZfXfy5eGmSyypwE +Z1gnSBOuh1EQxOLPXhykMeHaPZ5lZMeAprD/eHzWqw3lgTrcFkPBQMTAUWhHY4Wp +DKKRWZa2gbWf5peYzbpQtL0vgyDnDKgDyMkiXJksf97ITfbbP+VCSxLjYUaPWGmL +w4Ik6hzQmSSdSF/Va364W1tUY5D8D+DClrwg97K73nODifmeenwYHIn9Amdzt9Mh +cdQnbgZFJYDFEU1ZbJSlZfUCAwEAAaNTMFEwHQYDVR0OBBYEFCMAiSxHknRxYn6+ +SGNOAxhxDGbVMB8GA1UdIwQYMBaAFCMAiSxHknRxYn6+SGNOAxhxDGbVMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggGBAEwfiT0gFgWoQ2NNWCL/bUlN +WrBYlwy6ixL8V97LoJMmiLVq/2fxNm66r6AgKKGJLEXPtEQIT6dwpHIlVYvlvhCw +zC079E4/vNXEldoEiHWxzkeQL1hKFFm36Vm5k8hVoHPIeSgJyGO/e4x10m6mYKqZ +iSKcaSCdx8x7B8EHnve1E3H6LQjSilIfdg80niFdKzjE6v7zrTYDyGUz0Shasnms +oO9KtkhRFe04Fm495EmieoaMA7eT7ojucoZ5dwgMf+wDgfXsZ+hIiCeC9nVB1Waq +Biwv9sIJ8tI19Y5FPDXnaP6alkDm05u4PkE9l+BJ5Ky5VQcjLhyZ4X37W33aUU+z +ng3/c27j8MzXtkNIhO2yUX3Lr2ExJqaLB/hrWlWDJH/yRG3hGNGVl3RdCL82hHD+ 
+TMd6MKKM3XTqnwvHbUiJw63Xa5upOcoXvcoS6/sDvpJQKnuNB3DZ8LBsvnWPlucP +Ctnnve8XKCvXfNVrV0uXB5rjWIvNZ5eiNJsml8e4tQ== +-----END CERTIFICATE----- diff --git a/examples/fl/mock_cert/project/ca/root.key b/examples/fl/mock_cert/project/ca/root.key new file mode 100644 index 000000000..25617bed6 --- /dev/null +++ b/examples/fl/mock_cert/project/ca/root.key @@ -0,0 +1,40 @@ +-----BEGIN PRIVATE KEY----- +MIIG/wIBADANBgkqhkiG9w0BAQEFAASCBukwggblAgEAAoIBgQDMqoMT6iE/qRFJ +/X+N9pp/WUoYaDDliUYuDdgb0pqq6uqA50tfYmCOWal1K1Gq/4Hgi0OKsyj0bMem +tRNXXH8r8qtjLNNmGmyeZICDe9FT37gNr9uYWtVuwWpTI9bksxGVg9E0qx0U6fo+ +Puiu5ImDF/iYy1931ghijbOj0qWQM2dQi85baF/6uEHZ18b+c7K/toXCNhzJWrpw +88DUyPerhkoe/JTI2kSNKZwULuanVKazUZ4JIPF6NWhQGb/hcI+tTBkJXlETjrpN +8A3hqVp6vpZTiZfXfy5eGmSyypwEZ1gnSBOuh1EQxOLPXhykMeHaPZ5lZMeAprD/ +eHzWqw3lgTrcFkPBQMTAUWhHY4WpDKKRWZa2gbWf5peYzbpQtL0vgyDnDKgDyMki +XJksf97ITfbbP+VCSxLjYUaPWGmLw4Ik6hzQmSSdSF/Va364W1tUY5D8D+DClrwg +97K73nODifmeenwYHIn9Amdzt9MhcdQnbgZFJYDFEU1ZbJSlZfUCAwEAAQKCAYAK +0M/7KMT3tPA29XCHiLVYGYMy5alVLVCfdRfV5eaf4FONhauUBNeOw5ToSOZt9PFg +yRCZWdJw6EwSC+upEuLy6EYVCEoQ++sq4QDG8gTOkToMGckEX3h7++NUisZRG61y +4J5uUW9Iqvy7IV6b2h6c5j1lmwnBLsxj+Oe6C+himx97QDLNiHyEprbEKQUDxArO +2s8YGP1NyWpPjHIaTJcvYfoKHSr3r6EucePPpT2HMOVbqz/WF6mV3btU0FI1kFnP +KvUYJy+qEhZGgDHBGm80Y7MAjV/7Iu34oikQTv42QgBd8CwPODZAJs/VRW5U2OMS +DOj/quLSChknofVb5rEcjz2HXVmilsGoLAjdbt19r16XlwlFSihn73zZ/kfJWud+ +IATJ5FW1A9B64QZGhB45hGJcESqHWYq+x4i9puRL1XtuuV0uJQq79w0SKDSQXlrs +AZ1OzEaRdubFE7M49BU8MoSza9QvNhzPADOewVWlpkKrVPYRYXYgouq4FKpRpFkC +gcEA8vTVWh7m9S48W2qnaUGu4EhY3X0PuHUuH/muWEXC4n36xBRNb7jZkcC4qO1Z +B9bznOWJcVKr5AQ2Cq61DlQesltMBMvo8LCdyyV+9XgBpVtFJ4OlOVex7xsR6/Lp +gVO6H9SC0ej1LAHAK37tTLfejPvuMQSnAYuvmDfLoB8nZX4xaRuRz4iHMxi9ED6f +N2Zdhtutp84DypxFQWxeo8SZshAF5i77L9wEkCeA6JjPMwv2SknEuIl+oX2kyNw7 +0oxXAoHBANena6pl0NHqZV0GlrdeVYC4nQrKtLq9cHs5E/8nLmPH29b3T25wPKGU +jP75S0DBb21LC00slVvp8aNMZIv4WBHcAlwUUj6rutJ0Bvm7ZrPZEAZrfPil7BRY +QG2x7lrI+biLj/7hmNKjflfIi1XSlxlfP99Wy37ImIoebZdKEOqP4M7E8NlhK0Lt 
+YPGg3qxA+0NQsqk+XKrls3AK1pVq2aZTsfAjH+Z0wSqmLSM4tXT7v0rQFYTCe06E ++NB5TfWwkwKBwQCyArV3zICIUBIlIOX8dwW8iwWhcwpbqm/bOcOGJcb+0DM2C3IZ +U6UF5+Dk1NKQrevcn0mu4FXVQUifVxaNoxDCuaXfNdA82gsjVxvImt8J2u+2Xfxn +IVvbx0fAS0DPYxtSSxB24GsSjU3SELOprGbBga0p+TCsLz6/FtJ5RZpGAMoPKwYQ +uwXkaFHOXzOlEbmhH8AC3S1l/E25+77z2w6JqrfHydB9ZoVpYahPw/a8fh08nQQn ++YXwqPBdww+J2w0CgcEAnFaUOBDd5QBPgbQgGUk7JTkxKDyx7tsdK0fC1mv6Nn4S +QvJBVGfrnJwL52ClDInvFMWdqNIUaXDdK6xbDBn7Bt9/mm9k/GgU5TMWR39zQhiv +hGfyTnRDBLDB7IRcrtYaK46J0paL6tB57HvHf21O+ybRMEFE/2G/LApJGq+oOdQa +fuvJS14lNbzPVfxw0WG+hht/mjBKj948SpKg4+t1ZB4y1ksweirUSu3ztSAMdIV5 +NWxK3Vb8e3zswH3gZagfAoHBAJI6LPi7K8POviGrV2Aw0EHv/Fs3GY+zgiphqf69 +pR7fZqcKYOSwPugw5gYR1l9REpWyD9qfNzKXeuQbwwqDaW7VZlUEpp3mndDYG4tT +63W5h4O/vAmtAdF8oasyDv2iSni6xppI2QmlAoSrDIyCM3HyNOM0l4pHRQ+V1ncE +JXwiXwePyt/lx30ua7VWU442ZfnVfA78cK1AaTl2KRV00veHR1KLBvIkyVE5AHD8 +Ynsi1KB4GfSKoLUr9t/n0hqZ8g== +-----END PRIVATE KEY----- diff --git a/examples/fl/mock_cert/project/csr.conf b/examples/fl/mock_cert/project/csr.conf new file mode 100644 index 000000000..3285aed9f --- /dev/null +++ b/examples/fl/mock_cert/project/csr.conf @@ -0,0 +1,23 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[ req_distinguished_name ] +commonName = hasan-hp-zbook-15-g3.home + +[ req_ext ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = hasan-hp-zbook-15-g3.home + +[ v3_client ] +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +extendedKeyUsage = critical,serverAuth diff --git a/examples/fl/mock_cert/project/mlcube.py b/examples/fl/mock_cert/project/mlcube.py new file mode 100644 index 000000000..0fffa51ab --- /dev/null +++ b/examples/fl/mock_cert/project/mlcube.py @@ -0,0 +1,49 @@ +"""MLCube handler file""" + +import typer +import shutil +import json +import os + +app = typer.Typer() + + +def asserts(ca_config): + with 
open(ca_config) as f: + config = json.load(f) + assert config["address"] == "https://127.0.0.1" + assert config["port"] == 443 + assert config["fingerprint"] == "fingerprint" + assert config["client_provisioner"] == "auth0" + assert config["server_provisioner"] == "acme" + + +@app.command("trust") +def trust( + ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + shutil.copytree("/mlcube_project/ca/cert", pki_assets, dirs_exist_ok=True) + + +@app.command("get_client_cert") +def get_client_cert( + ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + os.system(f"sh /mlcube_project/sign.sh -o {pki_assets}") + + +@app.command("get_server_cert") +def get_server_cert( + ca_config: str = typer.Option(..., "--ca_config"), + pki_assets: str = typer.Option(..., "--pki_assets"), +): + asserts(ca_config) + os.system(f"sh /mlcube_project/sign.sh -o {pki_assets} -s") + + +if __name__ == "__main__": + app() diff --git a/examples/fl/mock_cert/project/requirements.txt b/examples/fl/mock_cert/project/requirements.txt new file mode 100644 index 000000000..a1662dd93 --- /dev/null +++ b/examples/fl/mock_cert/project/requirements.txt @@ -0,0 +1,2 @@ +typer==0.9.0 +PyYAML==6.0 \ No newline at end of file diff --git a/examples/fl/mock_cert/project/sign.sh b/examples/fl/mock_cert/project/sign.sh new file mode 100644 index 000000000..ffadc5144 --- /dev/null +++ b/examples/fl/mock_cert/project/sign.sh @@ -0,0 +1,31 @@ +while getopts so: flag; do + case "${flag}" in + o) OUT=${OPTARG} ;; + s) EXT="v3_server" ;; + esac +done + +EXT="${EXT:-v3_client}" + +if [ -z "$OUT" ]; then + echo "-o is required" + exit 1 +fi + +if [ -z "$MEDPERF_INPUT_CN" ]; then + echo "MEDPERF_INPUT_CN env var is required" + exit 1 +fi + +CSR_TEMPLATE=/mlcube_project/csr.conf +CA_KEY=/mlcube_project/ca/root.key +CA_CERT=/mlcube_project/ca/cert/root.crt + +sed -i 
"/^commonName = /c\commonName = $MEDPERF_INPUT_CN" $CSR_TEMPLATE +sed -i "/^DNS\.1 = /c\DNS.1 = $MEDPERF_INPUT_CN" $CSR_TEMPLATE + +openssl genpkey -algorithm RSA -out $OUT/key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key $OUT/key.key -out $OUT/csr.csr -config $CSR_TEMPLATE -extensions $EXT +openssl x509 -req -in $OUT/csr.csr -CA $CA_CERT -CAkey $CA_KEY \ + -CAcreateserial -out $OUT/crt.crt -days 36500 -sha384 +rm $OUT/csr.csr diff --git a/examples/fl/mock_cert/test.sh b/examples/fl/mock_cert/test.sh new file mode 100644 index 000000000..eaa800edf --- /dev/null +++ b/examples/fl/mock_cert/test.sh @@ -0,0 +1,6 @@ +mlcube run --mlcube ./mlcube/mlcube.yaml --task trust +# sh clean.sh +mlcube run --mlcube ./mlcube/mlcube.yaml --task get_client_cert -Pdocker.env_args="-e MEDPERF_INPUT_CN=user@example.com" +sh clean.sh +mlcube run --mlcube ./mlcube/mlcube.yaml --task get_server_cert -Pdocker.env_args="-e MEDPERF_INPUT_CN=https://example.com" +sh clean.sh From ed2d33ec880b81cbc6e369c3684a53425de48fa3 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Apr 2024 17:40:54 +0200 Subject: [PATCH 049/242] update test env --- server/.env.local.local-auth | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/.env.local.local-auth b/server/.env.local.local-auth index 45a131d3d..e0f3acfe4 100644 --- a/server/.env.local.local-auth +++ b/server/.env.local.local-auth @@ -24,8 +24,8 @@ AUTH_VERIFYING_KEY="-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A #CA configuration CA_NAME="MedPerf CA" -CA_CONFIG={"address":"127.0.0.1","port":443,"fingerprint":"fingerprint","client_provisioner":"auth0","server_provisioner":"acme"} +CA_CONFIG={"address":"https://127.0.0.1","port":443,"fingerprint":"fingerprint","client_provisioner":"auth0","server_provisioner":"acme"} CA_MLCUBE_NAME="MedPerf CA" -CA_MLCUBE_URL="url" -CA_MLCUBE_HASH="hash" -CA_MLCUBE_IMAGE_HASH="hash" \ No newline at end of file 
+CA_MLCUBE_URL="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/mock_cert/mlcube/mlcube.yaml" +CA_MLCUBE_HASH="d3d723fa6e14ea5f3ff1b215c4543295271bebf301d113c4953c5d54310b7dd1" +CA_MLCUBE_IMAGE_HASH="12da9239869a629b9c4fb8c04773219b74efcbeb48380065a0eba6c4f716c122" \ No newline at end of file From bf69385979500e4bfcefd111a58a7dfd1852b341 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Apr 2024 05:31:01 +0200 Subject: [PATCH 050/242] bug fixes --- cli/cli_tests_training.sh | 145 +++++- cli/medperf/commands/aggregator/run.py | 4 +- cli/medperf/commands/association/approval.py | 6 +- .../commands/association/association.py | 8 +- cli/medperf/commands/ca/ca.py | 2 +- cli/medperf/commands/dataset/associate.py | 2 +- cli/medperf/commands/dataset/dataset.py | 16 + .../{training/run.py => dataset/train.py} | 14 +- cli/medperf/commands/mlcube/run.py | 7 +- cli/medperf/commands/training/set_plan.py | 2 +- cli/medperf/commands/training/start_event.py | 18 +- cli/medperf/commands/training/training.py | 19 +- cli/medperf/comms/interface.py | 476 +++++++++--------- cli/medperf/comms/rest.py | 4 +- cli/medperf/config.py | 17 +- cli/medperf/entities/cube.py | 8 +- cli/medperf/entities/event.py | 8 +- cli/medperf/utils.py | 3 +- cli/requirements.txt | 6 +- cli/tests_setup.sh | 5 +- examples/fl/fl/project/utils.py | 16 +- examples/fl/fl/setup_test.sh | 65 +-- examples/fl/fl/setup_test_no_docker.sh | 124 +++++ examples/fl/mock_cert/project/Dockerfile | 2 + examples/fl/mock_cert/project/sign.sh | 11 +- examples/fl/mock_cert/test.sh | 8 +- server/.env.local.local-auth | 2 +- server/training/models.py | 4 +- server/training/views.py | 8 +- server/trainingevent/permissions.py | 37 +- server/trainingevent/serializers.py | 4 +- 31 files changed, 652 insertions(+), 399 deletions(-) rename cli/medperf/commands/{training/run.py => dataset/train.py} (87%) create mode 100644 examples/fl/fl/setup_test_no_docker.sh diff --git 
a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index dfda1f4af..470e92b86 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -94,7 +94,7 @@ medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE --operational checkFailed "Train prep submission failed" PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) -medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -p $TRAIN_PARAMS -a $TRAIN_WEIGHTS --operational +medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -a $TRAIN_WEIGHTS --operational checkFailed "traincube submission failed" TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) ########################################################## @@ -118,6 +118,17 @@ checkFailed "training exp approval failed" echo "\n" +########################################################## +echo "=====================================" +echo "Associate with ca" +echo "=====================================" +CA_UID=$(medperf ca ls | grep "MedPerf CA" | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) +medperf ca associate -t $TRAINING_UID -c $CA_UID -y +checkFailed "ca association failed" +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Activate aggowner profile" @@ -132,8 +143,9 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -HOSTNAME_=$(hostname -I | cut -d " " -f 1) -medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 +# HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out +HOSTNAME_=$(hostname -A | cut -d " " -f 1) +medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk 
'{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## @@ -150,6 +162,37 @@ checkFailed "aggregator association step failed" echo "\n" +########################################################## +echo "=====================================" +echo "Activate modelowner profile" +echo "=====================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Approve aggregator association" +echo "=====================================" +medperf association approve -t $TRAINING_UID -a $AGG_UID +checkFailed "agg association approval failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "submit plan" +echo "=====================================" +medperf training set_plan -t $TRAINING_UID -c $TRAINING_CONFIG -y +checkFailed "submit plan failed" + +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Activate dataowner profile" @@ -195,7 +238,7 @@ echo "\n" echo "=====================================" echo "Running data1 association step" echo "=====================================" -medperf training associate_dataset -d $DSET_1_UID -t $TRAINING_UID -y +medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y checkFailed "Data1 association step failed" ########################################################## @@ -246,7 +289,7 @@ echo "\n" echo "=====================================" echo "Running data2 association step" echo "=====================================" -medperf training associate_dataset -d $DSET_2_UID -t 
$TRAINING_UID -y +medperf dataset associate -d $DSET_2_UID -t $TRAINING_UID -y checkFailed "Data2 association step failed" ########################################################## @@ -264,40 +307,91 @@ echo "\n" ########################################################## echo "=====================================" -echo "Approve aggregator association" +echo "Approve data1 association" echo "=====================================" -medperf training approve_association -t $TRAINING_UID -a $AGG_UID -checkFailed "agg association approval failed" +medperf association approve -t $TRAINING_UID -d $DSET_1_UID +checkFailed "data1 association approval failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Approve data1 association" +echo "Approve data2 association" echo "=====================================" -medperf training approve_association -t $TRAINING_UID -d $DSET_1_UID -checkFailed "data1 association approval failed" +medperf association approve -t $TRAINING_UID -d $DSET_2_UID +checkFailed "data2 association approval failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Approve data2 association" +echo "start event" echo "=====================================" -medperf training approve_association -t $TRAINING_UID -d $DSET_2_UID -checkFailed "data2 association approval failed" +medperf training start_event -n event1 -t $TRAINING_UID -y +checkFailed "start event failed" + ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Lock experiment" +echo "Activate aggowner profile" echo "=====================================" -medperf training lock -t $TRAINING_UID -checkFailed "locking 
experiment failed" +medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get aggregator cert" +echo "=====================================" +medperf certificate get_server_certificate -t $TRAINING_UID +checkFailed "Get aggregator cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner profile" +echo "=====================================" +medperf profile activate testdata1 +checkFailed "testdata1 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get dataowner cert" +echo "=====================================" +medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get dataowner cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate dataowner2 profile" +echo "=====================================" +medperf profile activate testdata2 +checkFailed "testdata2 profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get dataowner2 cert" +echo "=====================================" +medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get dataowner2 cert failed" ########################################################## echo "\n" @@ -316,7 +410,7 @@ echo "\n" echo 
"=====================================" echo "Starting aggregator" echo "=====================================" -medperf aggregator start -a $AGG_UID -t $TRAINING_UID agg.log 2>&1 & +medperf aggregator start -t $TRAINING_UID agg.log 2>&1 & AGG_PID=$! # sleep so that the mlcube is run before we change profiles @@ -344,7 +438,7 @@ echo "\n" echo "=====================================" echo "Starting training with data1" echo "=====================================" -medperf training run -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & +medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & COL1_PID=$! # sleep so that the mlcube is run before we change profiles @@ -372,7 +466,7 @@ echo "\n" echo "=====================================" echo "Starting training with data2" echo "=====================================" -medperf training run -d $DSET_2_UID -t $TRAINING_UID +medperf dataset train -d $DSET_2_UID -t $TRAINING_UID checkFailed "data2 training failed" ########################################################## @@ -396,6 +490,17 @@ checkFailed "aggregator didn't exit successfully" echo "\n" +########################################################## +echo "=====================================" +echo "close event" +echo "=====================================" +medperf training close_event -t $TRAINING_UID -y +checkFailed "close event failed" + +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Logout users" diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index edf129cda..03eedd616 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -37,7 +37,7 @@ def prepare(self): self.event = TrainingEvent.from_experiment(self.training_exp_id) def validate(self): - if self.event.finished(): + if self.event.finished: msg = "The provided training 
experiment has to start a training event." raise InvalidArgumentError(msg) @@ -71,7 +71,7 @@ def run_experiment(self): "ca_cert_folder": self.ca.pki_assets, "plan_path": self.training_exp.plan_path, "collaborators": self.event.participants_list_path, - "output_logs": self.event.out_logs, + "output_logs": self.event.agg_out_logs, "output_weights": self.event.out_weights, "report_path": self.event.report_path, } diff --git a/cli/medperf/commands/association/approval.py b/cli/medperf/commands/association/approval.py index df486b420..ec7fe7999 100644 --- a/cli/medperf/commands/association/approval.py +++ b/cli/medperf/commands/association/approval.py @@ -47,11 +47,11 @@ def run( if training_exp_uid: if dataset_uid: comms.update_training_dataset_association( - benchmark_uid, dataset_uid, update + training_exp_uid, dataset_uid, update ) if aggregator_uid: comms.update_training_aggregator_association( - benchmark_uid, mlcube_uid, update + training_exp_uid, aggregator_uid, update ) if ca_uid: - comms.update_training_ca_association(benchmark_uid, mlcube_uid, update) + comms.update_training_ca_association(training_exp_uid, ca_uid, update) diff --git a/cli/medperf/commands/association/association.py b/cli/medperf/commands/association/association.py index 2e3bca2ac..b97255c72 100644 --- a/cli/medperf/commands/association/association.py +++ b/cli/medperf/commands/association/association.py @@ -43,9 +43,9 @@ def list( @app.command("approve") @clean_except def approve( - benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + benchmark_uid: int = typer.Option(None, "--benchmark", "-b", help="Benchmark UID"), training_exp_uid: int = typer.Option( - ..., "--training_exp", "-t", help="Training exp UID" + None, "--training_exp", "-t", help="Training exp UID" ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), @@ -76,9 +76,9 @@ def approve( 
@app.command("reject") @clean_except def reject( - benchmark_uid: int = typer.Option(..., "--benchmark", "-b", help="Benchmark UID"), + benchmark_uid: int = typer.Option(None, "--benchmark", "-b", help="Benchmark UID"), training_exp_uid: int = typer.Option( - ..., "--training_exp", "-t", help="Training exp UID" + None, "--training_exp", "-t", help="Training exp UID" ), dataset_uid: int = typer.Option(None, "--dataset", "-d", help="Dataset UID"), mlcube_uid: int = typer.Option(None, "--mlcube", "-m", help="MLCube UID"), diff --git a/cli/medperf/commands/ca/ca.py b/cli/medperf/commands/ca/ca.py index 167a072c9..580f00822 100644 --- a/cli/medperf/commands/ca/ca.py +++ b/cli/medperf/commands/ca/ca.py @@ -43,7 +43,7 @@ def submit( @app.command("associate") @clean_except def associate( - ca_id: int = typer.Option(..., "--ca_id", "-a", help="UID of CA to associate with"), + ca_id: int = typer.Option(..., "--ca_id", "-c", help="UID of CA to associate with"), training_exp_id: int = typer.Option( ..., "--training_exp_id", "-t", help="UID of training exp to associate with" ), diff --git a/cli/medperf/commands/dataset/associate.py b/cli/medperf/commands/dataset/associate.py index a9de70b80..e338bc831 100644 --- a/cli/medperf/commands/dataset/associate.py +++ b/cli/medperf/commands/dataset/associate.py @@ -28,4 +28,4 @@ def run( raise InvalidArgumentError( "no_cache argument is only valid when associating with a benchmark" ) - AssociateTrainingDataset.run(data_uid, benchmark_uid, approved=approved) + AssociateTrainingDataset.run(data_uid, training_exp_uid, approved=approved) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index d3243ec0d..ed1e710f7 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -10,6 +10,7 @@ from medperf.commands.dataset.prepare import DataPreparation from medperf.commands.dataset.set_operational import DatasetSetOperational from 
medperf.commands.dataset.associate import AssociateDataset +from medperf.commands.dataset.train import TrainingExecution app = typer.Typer() @@ -144,6 +145,21 @@ def associate( ui.print("✅ Done!") +@app.command("train") +@clean_except +def train( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + data_uid: int = typer.Option( + ..., "--data_uid", "-d", help="Registered Dataset UID" + ), +): + """Runs training""" + TrainingExecution.run(training_exp_id, data_uid) + config.ui.print("✅ Done!") + + @app.command("view") @clean_except def view( diff --git a/cli/medperf/commands/training/run.py b/cli/medperf/commands/dataset/train.py similarity index 87% rename from cli/medperf/commands/training/run.py rename to cli/medperf/commands/dataset/train.py index 69b2e97ae..bb67c09af 100644 --- a/cli/medperf/commands/training/run.py +++ b/cli/medperf/commands/dataset/train.py @@ -1,3 +1,4 @@ +import os from medperf import config from medperf.account_management.account_management import get_medperf_user_data from medperf.entities.ca import CA @@ -48,13 +49,14 @@ def validate(self): msg = "The provided dataset is not operational." raise InvalidArgumentError(msg) - if self.event.finished(): + if self.event.finished: msg = "The provided training experiment has to start a training event." raise InvalidArgumentError(msg) - if self.dataset.id not in self.training_exp.get_datasets_uids(): - msg = "The provided dataset is not associated." - raise InvalidArgumentError(msg) + # TODO: Do we need this? This basically would make participants list public to them + # if self.dataset.id not in TrainingExp.get_datasets_uids(self.training_exp_id): + # msg = "The provided dataset is not associated." 
+ # raise InvalidArgumentError(msg) def prepare_training_cube(self): self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") @@ -71,7 +73,7 @@ def prepare_plan(self): def prepare_pki_assets(self): ca = CA.from_experiment(self.training_exp_id) - trust(ca) + # trust(ca) self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) self.ca = ca @@ -84,7 +86,7 @@ def run_experiment(self): "node_cert_folder": self.dataset_pki_assets, "ca_cert_folder": self.ca.pki_assets, "plan_path": self.training_exp.plan_path, - "output_logs": self.event.out_logs, + "output_logs": os.path.join(self.event.col_out_logs, str(self.dataset.id)), } self.ui.text = "Running Training" diff --git a/cli/medperf/commands/mlcube/run.py b/cli/medperf/commands/mlcube/run.py index b239b4672..75c9fb19e 100644 --- a/cli/medperf/commands/mlcube/run.py +++ b/cli/medperf/commands/mlcube/run.py @@ -5,7 +5,8 @@ def run_mlcube(mlcube_path, task, out_logs, params, port, env): c = TestCube() - c.path = mlcube_path - c.cube_path = os.path.join(c.path, config.cube_filename) - c.params_path = os.path.join(c.path, config.params_filename) + c.cube_path = os.path.join(mlcube_path, config.cube_filename) + c.params_path = os.path.join( + mlcube_path, config.workspace_path, config.params_filename + ) c.run(task, out_logs, port=port, env_dict=env, **params) diff --git a/cli/medperf/commands/training/set_plan.py b/cli/medperf/commands/training/set_plan.py index f94123f05..1c3e47365 100644 --- a/cli/medperf/commands/training/set_plan.py +++ b/cli/medperf/commands/training/set_plan.py @@ -31,7 +31,7 @@ def __init__(self, training_exp_id: int, training_config_path: str, approval: bo self.ui = config.ui self.training_exp_id = training_exp_id self.training_config_path = training_config_path - self.approval = approval + self.approved = approval self.plan_out_path = generate_tmp_path() def validate(self): diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py 
index 13853bef6..e24c2fa48 100644 --- a/cli/medperf/commands/training/start_event.py +++ b/cli/medperf/commands/training/start_event.py @@ -6,16 +6,17 @@ class StartEvent: @classmethod - def run(cls, training_exp_id: int, approval: bool = False): - submission = cls(training_exp_id, approval) + def run(cls, training_exp_id: int, name: str, approval: bool = False): + submission = cls(training_exp_id, name, approval) submission.prepare() submission.validate() submission.create_participants_list() - submission.submit() - submission.write() + updated_body = submission.submit() + submission.write(updated_body) - def __init__(self, training_exp_id: int, approval): + def __init__(self, training_exp_id: int, name: str, approval): self.training_exp_id = training_exp_id + self.name = name self.approved = approval def prepare(self): @@ -46,7 +47,9 @@ def submit(self): self.approved = self.approved or approval_prompt(msg) self.event = TrainingEvent( - training_exp=self.training_exp_id, participants=self.participants_list + name=self.name, + training_exp=self.training_exp_id, + participants=self.participants_list, ) if self.approved: updated_body = self.event.upload() @@ -55,4 +58,5 @@ def submit(self): raise CleanExit("Event creation cancelled") def write(self, updated_body): - self.event.write(updated_body) + event = TrainingEvent(**updated_body) + event.write() diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index 5efa629a4..f42adbf19 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -6,7 +6,6 @@ from medperf.decorators import clean_except from medperf.commands.training.submit import SubmitTrainingExp -from medperf.commands.training.run import TrainingExecution from medperf.commands.training.set_plan import SetPlan from medperf.commands.training.start_event import StartEvent from medperf.commands.training.close_event import CloseEvent @@ -72,10 +71,11 @@ def start_event( 
training_exp_id: int = typer.Option( ..., "--training_exp_id", "-t", help="UID of the desired benchmark" ), + name: str = typer.Option(..., "--name", "-n", help="Name of the benchmark"), approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - StartEvent.run(training_exp_id, approval) + StartEvent.run(training_exp_id, name, approval) config.ui.print("✅ Done!") @@ -106,21 +106,6 @@ def cancel_event( config.ui.print("✅ Done!") -@app.command("run") -@clean_except -def run( - training_exp_id: int = typer.Option( - ..., "--training_exp_id", "-t", help="UID of the desired benchmark" - ), - data_uid: int = typer.Option( - ..., "--data_uid", "-d", help="Registered Dataset UID" - ), -): - """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - TrainingExecution.run(training_exp_id, data_uid) - config.ui.print("✅ Done!") - - @app.command("ls") @clean_except def list( diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py index a99127a2b..45516034f 100644 --- a/cli/medperf/comms/interface.py +++ b/cli/medperf/comms/interface.py @@ -1,4 +1,4 @@ -from typing import List +# from typing import List from abc import ABC, abstractmethod @@ -13,275 +13,275 @@ def __init__(self, source: str): token (str, Optional): authentication token to be used throughout communication. Defaults to None. """ - @classmethod - @abstractmethod - def parse_url(self, url: str) -> str: - """Parse the source URL so that it can be used by the comms implementation. - It should handle protocols and versioning to be able to communicate with the API. 
- - Args: - url (str): base URL - - Returns: - str: parsed URL with protocol and version - """ - - @abstractmethod - def get_current_user(self): - """Retrieve the currently-authenticated user information""" - - @abstractmethod - def get_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks in the platform. - - Returns: - List[dict]: all benchmarks information. - """ - - @abstractmethod - def get_benchmark(self, benchmark_uid: int) -> dict: - """Retrieves the benchmark specification file from the server - - Args: - benchmark_uid (int): uid for the desired benchmark - - Returns: - dict: benchmark specification - """ - - @abstractmethod - def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: - """Retrieves all the model associations of a benchmark. - - Args: - benchmark_uid (int): UID of the desired benchmark - - Returns: - list[int]: List of benchmark model associations - """ - - @abstractmethod - def get_user_benchmarks(self) -> List[dict]: - """Retrieves all benchmarks created by the user - - Returns: - List[dict]: Benchmarks data - """ - - @abstractmethod - def get_cubes(self) -> List[dict]: - """Retrieves all MLCubes in the platform - - Returns: - List[dict]: List containing the data of all MLCubes - """ - - @abstractmethod - def get_cube_metadata(self, cube_uid: int) -> dict: - """Retrieves metadata about the specified cube - - Args: - cube_uid (int): UID of the desired cube. - - Returns: - dict: Dictionary containing url and hashes for the cube files - """ - - @abstractmethod - def get_user_cubes(self) -> List[dict]: - """Retrieves metadata from all cubes registered by the user - - Returns: - List[dict]: List of dictionaries containing the mlcubes registration information - """ - - @abstractmethod - def upload_benchmark(self, benchmark_dict: dict) -> int: - """Uploads a new benchmark to the server. 
- - Args: - benchmark_dict (dict): benchmark_data to be uploaded - - Returns: - int: UID of newly created benchmark - """ - - @abstractmethod - def upload_mlcube(self, mlcube_body: dict) -> int: - """Uploads an MLCube instance to the platform - - Args: - mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes + # @classmethod + # @abstractmethod + # def parse_url(self, url: str) -> str: + # """Parse the source URL so that it can be used by the comms implementation. + # It should handle protocols and versioning to be able to communicate with the API. - Returns: - int: id of the created mlcube instance on the platform - """ + # Args: + # url (str): base URL - @abstractmethod - def get_datasets(self) -> List[dict]: - """Retrieves all datasets in the platform + # Returns: + # str: parsed URL with protocol and version + # """ - Returns: - List[dict]: List of data from all datasets - """ + # @abstractmethod + # def get_current_user(self): + # """Retrieve the currently-authenticated user information""" - @abstractmethod - def get_dataset(self, dset_uid: str) -> dict: - """Retrieves a specific dataset + # @abstractmethod + # def get_benchmarks(self) -> List[dict]: + # """Retrieves all benchmarks in the platform. - Args: - dset_uid (str): Dataset UID + # Returns: + # List[dict]: all benchmarks information. 
+ # """ - Returns: - dict: Dataset metadata - """ + # @abstractmethod + # def get_benchmark(self, benchmark_uid: int) -> dict: + # """Retrieves the benchmark specification file from the server - @abstractmethod - def get_user_datasets(self) -> dict: - """Retrieves all datasets registered by the user + # Args: + # benchmark_uid (int): uid for the desired benchmark - Returns: - dict: dictionary with the contents of each dataset registration query - """ + # Returns: + # dict: benchmark specification + # """ - @abstractmethod - def upload_dataset(self, reg_dict: dict) -> int: - """Uploads registration data to the server, under the sha name of the file. + # @abstractmethod + # def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]: + # """Retrieves all the model associations of a benchmark. - Args: - reg_dict (dict): Dictionary containing registration information. + # Args: + # benchmark_uid (int): UID of the desired benchmark - Returns: - int: id of the created dataset registration. 
- """ + # Returns: + # list[int]: List of benchmark model associations + # """ - @abstractmethod - def get_results(self) -> List[dict]: - """Retrieves all results + # @abstractmethod + # def get_user_benchmarks(self) -> List[dict]: + # """Retrieves all benchmarks created by the user - Returns: - List[dict]: List of results - """ + # Returns: + # List[dict]: Benchmarks data + # """ - @abstractmethod - def get_result(self, result_uid: str) -> dict: - """Retrieves a specific result data + # @abstractmethod + # def get_cubes(self) -> List[dict]: + # """Retrieves all MLCubes in the platform - Args: - result_uid (str): Result UID + # Returns: + # List[dict]: List containing the data of all MLCubes + # """ - Returns: - dict: Result metadata - """ + # @abstractmethod + # def get_cube_metadata(self, cube_uid: int) -> dict: + # """Retrieves metadata about the specified cube - @abstractmethod - def get_user_results(self) -> dict: - """Retrieves all results registered by the user + # Args: + # cube_uid (int): UID of the desired cube. - Returns: - dict: dictionary with the contents of each dataset registration query - """ + # Returns: + # dict: Dictionary containing url and hashes for the cube files + # """ - @abstractmethod - def get_benchmark_results(self, benchmark_id: int) -> dict: - """Retrieves all results for a given benchmark + # @abstractmethod + # def get_user_cubes(self) -> List[dict]: + # """Retrieves metadata from all cubes registered by the user - Args: - benchmark_id (int): benchmark ID to retrieve results from + # Returns: + # List[dict]: List of dictionaries containing the mlcubes registration information + # """ - Returns: - dict: dictionary with the contents of each result in the specified benchmark - """ + # @abstractmethod + # def upload_benchmark(self, benchmark_dict: dict) -> int: + # """Uploads a new benchmark to the server. - @abstractmethod - def upload_result(self, results_dict: dict) -> int: - """Uploads result to the server. 
+ # Args: + # benchmark_dict (dict): benchmark_data to be uploaded - Args: - results_dict (dict): Dictionary containing results information. + # Returns: + # int: UID of newly created benchmark + # """ - Returns: - int: id of the generated results entry - """ + # @abstractmethod + # def upload_mlcube(self, mlcube_body: dict) -> int: + # """Uploads an MLCube instance to the platform - @abstractmethod - def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}): - """Create a Dataset Benchmark association + # Args: + # mlcube_body (dict): Dictionary containing all the relevant data for creating mlcubes - Args: - data_uid (int): Registered dataset UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. - """ + # Returns: + # int: id of the created mlcube instance on the platform + # """ - @abstractmethod - def associate_cube(self, cube_uid: str, benchmark_uid: int, metadata: dict = {}): - """Create an MLCube-Benchmark association + # @abstractmethod + # def get_datasets(self) -> List[dict]: + # """Retrieves all datasets in the platform - Args: - cube_uid (str): MLCube UID - benchmark_uid (int): Benchmark UID - metadata (dict, optional): Additional metadata. Defaults to {}. 
- """ + # Returns: + # List[dict]: List of data from all datasets + # """ - @abstractmethod - def set_dataset_association_approval( - self, dataset_uid: str, benchmark_uid: str, status: str - ): - """Approves a dataset association + # @abstractmethod + # def get_dataset(self, dset_uid: str) -> dict: + # """Retrieves a specific dataset - Args: - dataset_uid (str): Dataset UID - benchmark_uid (str): Benchmark UID - status (str): Approval status to set for the association - """ + # Args: + # dset_uid (str): Dataset UID - @abstractmethod - def set_mlcube_association_approval( - self, mlcube_uid: str, benchmark_uid: str, status: str - ): - """Approves an mlcube association + # Returns: + # dict: Dataset metadata + # """ - Args: - mlcube_uid (str): Dataset UID - benchmark_uid (str): Benchmark UID - status (str): Approval status to set for the association - """ + # @abstractmethod + # def get_user_datasets(self) -> dict: + # """Retrieves all datasets registered by the user - @abstractmethod - def get_datasets_associations(self) -> List[dict]: - """Get all dataset associations related to the current user + # Returns: + # dict: dictionary with the contents of each dataset registration query + # """ - Returns: - List[dict]: List containing all associations information - """ + # @abstractmethod + # def upload_dataset(self, reg_dict: dict) -> int: + # """Uploads registration data to the server, under the sha name of the file. 
- @abstractmethod - def get_cubes_associations(self) -> List[dict]: - """Get all cube associations related to the current user - - Returns: - List[dict]: List containing all associations information - """ - - @abstractmethod - def set_mlcube_association_priority( - self, benchmark_uid: str, mlcube_uid: str, priority: int - ): - """Sets the priority of an mlcube-benchmark association - - Args: - mlcube_uid (str): MLCube UID - benchmark_uid (str): Benchmark UID - priority (int): priority value to set for the association - """ - - @abstractmethod - def update_dataset(self, dataset_id: int, data: dict): - """Updates the contents of a datasets identified by dataset_id to the new data dictionary. - Updates may be partial. - - Args: - dataset_id (int): ID of the dataset to update - data (dict): Updated information of the dataset. - """ + # Args: + # reg_dict (dict): Dictionary containing registration information. + + # Returns: + # int: id of the created dataset registration. + # """ + + # @abstractmethod + # def get_results(self) -> List[dict]: + # """Retrieves all results + + # Returns: + # List[dict]: List of results + # """ + + # @abstractmethod + # def get_result(self, result_uid: str) -> dict: + # """Retrieves a specific result data + + # Args: + # result_uid (str): Result UID + + # Returns: + # dict: Result metadata + # """ + + # @abstractmethod + # def get_user_results(self) -> dict: + # """Retrieves all results registered by the user + + # Returns: + # dict: dictionary with the contents of each dataset registration query + # """ + + # @abstractmethod + # def get_benchmark_results(self, benchmark_id: int) -> dict: + # """Retrieves all results for a given benchmark + + # Args: + # benchmark_id (int): benchmark ID to retrieve results from + + # Returns: + # dict: dictionary with the contents of each result in the specified benchmark + # """ + + # @abstractmethod + # def upload_result(self, results_dict: dict) -> int: + # """Uploads result to the server. 
+ + # Args: + # results_dict (dict): Dictionary containing results information. + + # Returns: + # int: id of the generated results entry + # """ + + # @abstractmethod + # def associate_dset(self, data_uid: int, benchmark_uid: int, metadata: dict = {}): + # """Create a Dataset Benchmark association + + # Args: + # data_uid (int): Registered dataset UID + # benchmark_uid (int): Benchmark UID + # metadata (dict, optional): Additional metadata. Defaults to {}. + # """ + + # @abstractmethod + # def associate_cube(self, cube_uid: str, benchmark_uid: int, metadata: dict = {}): + # """Create an MLCube-Benchmark association + + # Args: + # cube_uid (str): MLCube UID + # benchmark_uid (int): Benchmark UID + # metadata (dict, optional): Additional metadata. Defaults to {}. + # """ + + # @abstractmethod + # def set_dataset_association_approval( + # self, dataset_uid: str, benchmark_uid: str, status: str + # ): + # """Approves a dataset association + + # Args: + # dataset_uid (str): Dataset UID + # benchmark_uid (str): Benchmark UID + # status (str): Approval status to set for the association + # """ + + # @abstractmethod + # def set_mlcube_association_approval( + # self, mlcube_uid: str, benchmark_uid: str, status: str + # ): + # """Approves an mlcube association + + # Args: + # mlcube_uid (str): Dataset UID + # benchmark_uid (str): Benchmark UID + # status (str): Approval status to set for the association + # """ + + # @abstractmethod + # def get_datasets_associations(self) -> List[dict]: + # """Get all dataset associations related to the current user + + # Returns: + # List[dict]: List containing all associations information + # """ + + # @abstractmethod + # def get_cubes_associations(self) -> List[dict]: + # """Get all cube associations related to the current user + + # Returns: + # List[dict]: List containing all associations information + # """ + + # @abstractmethod + # def set_mlcube_association_priority( + # self, benchmark_uid: str, mlcube_uid: str, priority: int + # 
): + # """Sets the priority of an mlcube-benchmark association + + # Args: + # mlcube_uid (str): MLCube UID + # benchmark_uid (str): Benchmark UID + # priority (int): priority value to set for the association + # """ + + # @abstractmethod + # def update_dataset(self, dataset_id: int, data: dict): + # """Updates the contents of a datasets identified by dataset_id to the new data dictionary. + # Updates may be partial. + + # Args: + # dataset_id (int): ID of the dataset to update + # data (dict): Updated information of the dataset. + # """ diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py index 6783230c8..2927aff92 100644 --- a/cli/medperf/comms/rest.py +++ b/cli/medperf/comms/rest.py @@ -222,7 +222,7 @@ def get_training_exp(self, training_exp_id: int) -> dict: Returns: dict: training_exp specification """ - url = f"{self.server_url}/training/{training_exp_id}" + url = f"{self.server_url}/training/{training_exp_id}/" error_msg = "Could not retrieve training experiment" return self.__get(url, error_msg) @@ -706,7 +706,7 @@ def associate_training_ca(self, ca_id: int, training_exp_id: int): """ url = f"{self.server_url}/cas/training/" data = { - "aggregator": ca_id, + "ca": ca_id, "training_exp": training_exp_id, "approval_status": Status.PENDING.value, } diff --git a/cli/medperf/config.py b/cli/medperf/config.py index b870959a2..09e5179bc 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -64,6 +64,8 @@ tests_folder = "tests" training_folder = "training" aggregators_folder = "aggregators" +cas_folder = "cas" +training_events_folder = "training_events" default_base_storage = str(Path.home().resolve() / ".medperf") @@ -120,6 +122,14 @@ "base": default_base_storage, "name": aggregators_folder, }, + "cas_folder": { + "base": default_base_storage, + "name": cas_folder, + }, + "training_events_folder": { + "base": default_base_storage, + "name": training_events_folder, + }, } root_folders = [ @@ -138,6 +148,8 @@ "tests_folder", 
"training_folder", "aggregators_folder", + "cas_folder", + "training_events_folder", ] # MedPerf filenames conventions @@ -146,6 +158,8 @@ test_report_file = "test_report.yaml" reg_file = "registration-info.yaml" agg_file = "agg-info.yaml" +ca_file = "ca-info.yaml" +training_event_file = "event.yaml" cube_metadata_filename = "mlcube-meta.yaml" log_file = "medperf.log" log_package_file = "medperf_logs.tar.gz" @@ -156,7 +170,8 @@ participants_list_filename = "cols.yaml" training_exp_plan_filename = "plan.yaml" training_report_file = "report.yaml" -training_out_logs = "logs" +training_out_agg_logs = "agg_logs" +training_out_col_logs = "col_logs" training_out_weights = "weights" ca_cert_folder = "ca_cert" ca_config_file = "ca_config.json" diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index b327417e2..4451fe1aa 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -249,7 +249,13 @@ def run( kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" cmd += f" --mlcube={self.cube_path} --task={task} --platform={config.platform}" - if task not in ["train", "start_aggregator"]: + if task not in [ + "train", + "start_aggregator", + "trust", + "get_client_cert", + "get_server_cert", + ]: cmd += " --network=none" if config.gpus is not None: cmd += f" --gpus={config.gpus}" diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py index 484ae4767..dce943aa0 100644 --- a/cli/medperf/entities/event.py +++ b/cli/medperf/entities/event.py @@ -27,7 +27,7 @@ class TrainingEvent(Entity, MedperfSchema): participants: dict finished: bool = False finished_at: Optional[datetime] - report: dict = {} + report: Optional[dict] @staticmethod def get_type(): @@ -56,8 +56,12 @@ def __init__(self, *args, **kwargs): self.participants_list_path = os.path.join( self.path, config.participants_list_filename ) - self.out_logs = os.path.join(self.path, config.training_out_logs) + self.agg_out_logs = 
os.path.join(self.path, config.training_out_agg_logs) + self.col_out_logs = os.path.join(self.path, config.training_out_col_logs) self.out_weights = os.path.join(self.path, config.training_out_weights) + + # TODO: move this into a subfolder, since participants list file is in the same folder + # which means this folder will be mounted read-only self.report_path = os.path.join(self.path, config.training_report_file) @classmethod diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 7d8042fde..103490ac0 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -492,4 +492,5 @@ def get_pki_assets_path(common_name: str, ca_name: str): def get_participant_label(email, data_id): - return f"{email}_d{data_id}" + # return f"d{data_id}" # TODO: use this when building openfl fork + return f"{email}" diff --git a/cli/requirements.txt b/cli/requirements.txt index 037126d1d..02d8ee05a 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -10,9 +10,9 @@ colorama==0.4.4 time-machine==2.4.0 pytest-mock==1.13.0 pyfakefs==5.0.0 -mlcube @ git+https://github.com/hasan7n/mlcube@11632a85064f653e6c5a59e3c6a4996ab9fe510b#subdirectory=mlcube -mlcube-docker @ git+https://github.com/hasan7n/mlcube@11632a85064f653e6c5a59e3c6a4996ab9fe510b#subdirectory=runners/mlcube_docker -mlcube-singularity @ git+https://github.com/hasan7n/mlcube@11632a85064f653e6c5a59e3c6a4996ab9fe510b#subdirectory=runners/mlcube_singularity +mlcube @ git+https://github.com/hasan7n/mlcube@7fbb1828c9c79ab52023d2919e157468deecb95a#subdirectory=mlcube +mlcube-docker @ git+https://github.com/hasan7n/mlcube@7fbb1828c9c79ab52023d2919e157468deecb95a#subdirectory=runners/mlcube_docker +mlcube-singularity @ git+https://github.com/hasan7n/mlcube@7fbb1828c9c79ab52023d2919e157468deecb95a#subdirectory=runners/mlcube_singularity validators==0.18.2 merge-args==0.1.4 synapseclient==4.1.1 diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index d10265c07..468248fb4 100644 --- a/cli/tests_setup.sh +++ 
b/cli/tests_setup.sh @@ -108,8 +108,7 @@ METRIC_MLCUBE="$ASSETS_URL/metrics/mlcube/mlcube.yaml" METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" # FL cubes -TRAIN_MLCUBE="https://storage.googleapis.com/medperf-storage/testfl/mlcube1.0.0.yaml" -TRAIN_PARAMS="https://storage.googleapis.com/medperf-storage/testfl/parameters1.0.0.yaml" +TRAIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/fl/mlcube/mlcube.yaml" TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" # test users credentials @@ -124,3 +123,5 @@ AGGOWNER="testao@example.com" PREP_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/data_preparator/mlcube" MODEL_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/model_custom_cnn/mlcube" METRIC_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/metrics/mlcube" + +TRAINING_CONFIG="$(dirname $(dirname $(realpath "$0")))/examples/fl/fl/mlcube/workspace/training_config.yaml" diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index 76db41459..4cee2ba39 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -48,18 +48,12 @@ def prepare_plan(plan_path, fl_workspace): def prepare_cols_list(collaborators_file, fl_workspace): with open(collaborators_file) as f: - cols = f.read().strip().split("\n") - cols = [col.strip().split(",") for col in cols] - cols_dict = {} + cols_dict = yaml.safe_load(f) cn_different = False - for col in cols: - if len(col) == 1: - cols_dict[col[0]] = col[0] - else: - assert len(col) == 2 - cols_dict[col[0]] = col[1] - if col[0] != col[1]: - cn_different = True + for col_label in cols_dict.keys(): + cn = cols_dict[col_label] + if cn != col_label: + cn_different = True if not cn_different: # quick hack to support old and new openfl versions cols_dict = list(cols_dict.keys()) diff --git 
a/examples/fl/fl/setup_test.sh b/examples/fl/fl/setup_test.sh index 75a3d68f5..542dd7164 100644 --- a/examples/fl/fl/setup_test.sh +++ b/examples/fl/fl/setup_test.sh @@ -31,72 +31,41 @@ mkdir ./ca HOSTNAME_=$(hostname -A | cut -d " " -f 1) -# root ca -openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 -openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ - -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" +medperf mlcube run --mlcube ../mock_cert/mlcube --task trust +mv ../mock_cert/mlcube/workspace/pki_assets/* ./ca # col1 -sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf -cd mlcube_col1/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL1_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col1/workspace/node_cert +cp -r ./ca/* ./mlcube_col1/workspace/ca_cert # col2 -sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf -cd mlcube_col2/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL2_CN +mv 
../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col2/workspace/node_cert +cp -r ./ca/* ./mlcube_col2/workspace/ca_cert # col3 if ${TWO_COL_SAME_CERT}; then cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert else - sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf - sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf - cd mlcube_col3/workspace/node_cert - openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 - openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client - openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 - rm csr.csr - cp ../../../ca/root.crt ../ca_cert/ - cd ../../../ + medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL3_CN + mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col3/workspace/node_cert + cp -r ./ca/* ./mlcube_col3/workspace/ca_cert fi -# agg -sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf -cd mlcube_agg/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_server_cert -e MEDPERF_INPUT_CN=$HOSTNAME_ +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_agg/workspace/node_cert +cp -r ./ca/* ./mlcube_agg/workspace/ca_cert # aggregator_config echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols file -echo 
"$COL1_LABEL,$COL1_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL2_LABEL,$COL2_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL3_LABEL,$COL3_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml # data download cd mlcube_col1/workspace/ diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh new file mode 100644 index 000000000..915fe85a8 --- /dev/null +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -0,0 +1,124 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. 
it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col2 +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + sed -i "/^commonName = 
/c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ +fi + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. 
+ +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. diff --git a/examples/fl/mock_cert/project/Dockerfile b/examples/fl/mock_cert/project/Dockerfile index 91c477415..cf625ca6b 100644 --- a/examples/fl/mock_cert/project/Dockerfile +++ b/examples/fl/mock_cert/project/Dockerfile @@ -8,4 +8,6 @@ ENV LANG C.UTF-8 COPY . /mlcube_project +RUN chmod a+r /mlcube_project/ca/root.key + ENTRYPOINT ["python3", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/mock_cert/project/sign.sh b/examples/fl/mock_cert/project/sign.sh index ffadc5144..2b2fd3727 100644 --- a/examples/fl/mock_cert/project/sign.sh +++ b/examples/fl/mock_cert/project/sign.sh @@ -17,9 +17,12 @@ if [ -z "$MEDPERF_INPUT_CN" ]; then exit 1 fi -CSR_TEMPLATE=/mlcube_project/csr.conf -CA_KEY=/mlcube_project/ca/root.key -CA_CERT=/mlcube_project/ca/cert/root.crt +mkdir -p $OUT +cp /mlcube_project/csr.conf $OUT/ +cp -r /mlcube_project/ca $OUT/ +CSR_TEMPLATE=$OUT/csr.conf +CA_KEY=$OUT/ca/root.key +CA_CERT=$OUT/ca/cert/root.crt sed -i "/^commonName = /c\commonName = $MEDPERF_INPUT_CN" $CSR_TEMPLATE sed -i "/^DNS\.1 = /c\DNS.1 = $MEDPERF_INPUT_CN" $CSR_TEMPLATE @@ -29,3 +32,5 @@ openssl req -new -key $OUT/key.key -out $OUT/csr.csr -config $CSR_TEMPLATE -exte openssl x509 -req -in $OUT/csr.csr -CA $CA_CERT -CAkey $CA_KEY \ -CAcreateserial -out $OUT/crt.crt -days 36500 -sha384 rm $OUT/csr.csr +rm -r $OUT/ca +rm -r $OUT/csr.conf diff --git a/examples/fl/mock_cert/test.sh b/examples/fl/mock_cert/test.sh index eaa800edf..79eaa584f 100644 --- a/examples/fl/mock_cert/test.sh +++ b/examples/fl/mock_cert/test.sh @@ -1,6 +1,6 @@ -mlcube run --mlcube ./mlcube/mlcube.yaml --task trust 
-# sh clean.sh -mlcube run --mlcube ./mlcube/mlcube.yaml --task get_client_cert -Pdocker.env_args="-e MEDPERF_INPUT_CN=user@example.com" +medperf mlcube run --mlcube ./mlcube --task trust sh clean.sh -mlcube run --mlcube ./mlcube/mlcube.yaml --task get_server_cert -Pdocker.env_args="-e MEDPERF_INPUT_CN=https://example.com" +medperf mlcube run --mlcube ./mlcube --task get_client_cert -e MEDPERF_INPUT_CN=user@example.com +sh clean.sh +medperf mlcube run --mlcube ./mlcube --task get_server_cert -e MEDPERF_INPUT_CN=https://example.com sh clean.sh diff --git a/server/.env.local.local-auth b/server/.env.local.local-auth index e0f3acfe4..681d93918 100644 --- a/server/.env.local.local-auth +++ b/server/.env.local.local-auth @@ -28,4 +28,4 @@ CA_CONFIG={"address":"https://127.0.0.1","port":443,"fingerprint":"fingerprint", CA_MLCUBE_NAME="MedPerf CA" CA_MLCUBE_URL="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/mock_cert/mlcube/mlcube.yaml" CA_MLCUBE_HASH="d3d723fa6e14ea5f3ff1b215c4543295271bebf301d113c4953c5d54310b7dd1" -CA_MLCUBE_IMAGE_HASH="12da9239869a629b9c4fb8c04773219b74efcbeb48380065a0eba6c4f716c122" \ No newline at end of file +CA_MLCUBE_IMAGE_HASH="24606b9b48e0da09b01d76870f02a1e24b5128798a6c4b65aae17ab85cb16208" \ No newline at end of file diff --git a/server/training/models.py b/server/training/models.py index 0726ed87f..5119a5e8c 100644 --- a/server/training/models.py +++ b/server/training/models.py @@ -57,14 +57,14 @@ def event(self): @property def aggregator(self): aggregator_assoc = ( - self.aggregator_association_set.all().order_by("created_at").last() + self.experimentaggregator_set.all().order_by("created_at").last() ) if aggregator_assoc and aggregator_assoc.approval_status == "APPROVED": return aggregator_assoc.aggregator @property def ca(self): - ca_assoc = self.ca_association_set.all().order_by("created_at").last() + ca_assoc = self.experimentca_set.all().order_by("created_at").last() if ca_assoc 
and ca_assoc.approval_status == "APPROVED": return ca_assoc.ca diff --git a/server/training/views.py b/server/training/views.py index a2c13edac..0a9a75d54 100644 --- a/server/training/views.py +++ b/server/training/views.py @@ -59,7 +59,9 @@ def post(self, request, format=None): class TrainingAggregator(GenericAPIView): - permission_classes = [IsAdmin | IsExpOwner | IsAssociatedDatasetOwner] + permission_classes = [ + IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + ] serializer_class = AggregatorSerializer queryset = "" @@ -99,7 +101,7 @@ def get(self, request, pk, format=None): Retrieve datasets associated with a training experiment instance. """ training_exp = self.get_object(pk) - datasets = training_exp.traindataset_association_set.all() + datasets = training_exp.experimentdataset_set.all() datasets = self.paginate_queryset(datasets) serializer = TrainingExperimentListofDatasetsSerializer(datasets, many=True) return self.get_paginated_response(serializer.data) @@ -225,7 +227,7 @@ def get(self, request, pk, format=None): """ training_exp = self.get_object(pk) latest_datasets_assocs_status = ( - training_exp.traindataset_association_set.all() + training_exp.experimentdataset_set.all() .filter(dataset__id=OuterRef("id")) .order_by("-created_at")[:1] .values("approval_status") diff --git a/server/trainingevent/permissions.py b/server/trainingevent/permissions.py index f0eb1ee94..1bb0c4b0d 100644 --- a/server/trainingevent/permissions.py +++ b/server/trainingevent/permissions.py @@ -1,5 +1,5 @@ from rest_framework.permissions import BasePermission -from .models import TrainingExperiment +from .models import TrainingEvent, TrainingExperiment class IsAdmin(BasePermission): @@ -14,8 +14,24 @@ def get_object(self, tid): except TrainingExperiment.DoesNotExist: return None + def get_event_object(self, pk): + try: + return TrainingEvent.objects.get(pk=pk) + except TrainingEvent.DoesNotExist: + return None + def has_permission(self, request, view): - 
tid = view.kwargs.get("tid", None) + if request.method == "POST": + tid = request.data.get("training_exp", None) + else: + pk = view.kwargs.get("pk", None) + if not pk: + return False + event = self.get_event_object(pk) + if not event: + return False + tid = event.training_exp.id + if not tid: return False training_exp = self.get_object(tid) @@ -28,22 +44,23 @@ def has_permission(self, request, view): class IsAggregatorOwner(BasePermission): - def get_object(self, tid): + def get_object(self, pk): try: - return TrainingExperiment.objects.get(pk=tid) - except TrainingExperiment.DoesNotExist: + return TrainingEvent.objects.get(pk=pk) + except TrainingEvent.DoesNotExist: return None def has_permission(self, request, view): - tid = view.kwargs.get("tid", None) - if not tid: + pk = view.kwargs.get("pk", None) + if not pk: return False - training_exp = self.get_object(tid) - if not training_exp: + event = self.get_object(pk) + if not event: return False - aggregator = training_exp.aggregator + aggregator = event.training_exp.aggregator if not aggregator: return False + if aggregator.owner.id == request.user.id: return True else: diff --git a/server/trainingevent/serializers.py b/server/trainingevent/serializers.py index 46f29243c..863f83688 100644 --- a/server/trainingevent/serializers.py +++ b/server/trainingevent/serializers.py @@ -1,6 +1,5 @@ from rest_framework import serializers from .models import TrainingEvent -from training.models import TrainingExperiment from django.utils import timezone @@ -11,7 +10,7 @@ class Meta: read_only_fields = ["finished", "finished_at", "report", "owner"] def validate(self, data): - training_exp = TrainingExperiment.objects.get(pk=data["training_exp"]) + training_exp = data["training_exp"] if training_exp.approval_status != "APPROVED": raise serializers.ValidationError( "User cannot create an event unless the experiment is approved" @@ -45,6 +44,7 @@ class Meta: "participants", "finished", "owner", + "name", ] def validate(self, 
data): From c4b3fa7e1b4709e3683630e5b401563d19c3c2f2 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Apr 2024 16:20:55 +0200 Subject: [PATCH 051/242] use event's report within a subfolder --- cli/medperf/config.py | 1 + cli/medperf/entities/event.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 09e5179bc..85429682d 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -170,6 +170,7 @@ participants_list_filename = "cols.yaml" training_exp_plan_filename = "plan.yaml" training_report_file = "report.yaml" +training_report_folder = "report" training_out_agg_logs = "agg_logs" training_out_col_logs = "col_logs" training_out_weights = "weights" diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py index dce943aa0..0b971c4e1 100644 --- a/cli/medperf/entities/event.py +++ b/cli/medperf/entities/event.py @@ -59,10 +59,9 @@ def __init__(self, *args, **kwargs): self.agg_out_logs = os.path.join(self.path, config.training_out_agg_logs) self.col_out_logs = os.path.join(self.path, config.training_out_col_logs) self.out_weights = os.path.join(self.path, config.training_out_weights) - - # TODO: move this into a subfolder, since participants list file is in the same folder - # which means this folder will be mounted read-only - self.report_path = os.path.join(self.path, config.training_report_file) + self.report_path = os.path.join( + self.path, config.training_report_folder, config.training_report_file + ) @classmethod def from_experiment(cls, training_exp_uid: int) -> "TrainingEvent": From 0ee1b7f9ec45b2abc17550ed88f26f3fe1d15b9a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Apr 2024 16:36:54 +0200 Subject: [PATCH 052/242] use IP address instead of hostname in integration tests --- cli/cli_tests_training.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 470e92b86..206a5eafd 
100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -143,8 +143,8 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -# HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out -HOSTNAME_=$(hostname -A | cut -d " " -f 1) +HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out +# HOSTNAME_=$(hostname -A | cut -d " " -f 1) medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) From be5a0ba3829b17d0a85f572902de03c782cefbda Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Apr 2024 19:07:53 +0200 Subject: [PATCH 053/242] fix tests --- cli/cli_tests_training.sh | 14 ++++++++++++-- server/aggregator/models.py | 2 +- server/ca/models.py | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 206a5eafd..4d1c703ad 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -143,8 +143,8 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out -# HOSTNAME_=$(hostname -A | cut -d " " -f 1) +# HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out +HOSTNAME_=$(hostname -A | cut -d " " -f 1) medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) @@ -490,6 +490,16 @@ checkFailed "aggregator didn't exit successfully" echo "\n" +########################################################## +echo "=====================================" +echo "Activate modelowner 
profile" +echo "=====================================" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "close event" diff --git a/server/aggregator/models.py b/server/aggregator/models.py index 8947c626d..8f29fcb15 100644 --- a/server/aggregator/models.py +++ b/server/aggregator/models.py @@ -19,7 +19,7 @@ class Aggregator(models.Model): modified_at = models.DateTimeField(auto_now=True) def __str__(self): - return self.config + return str(self.config) class Meta: ordering = ["created_at"] diff --git a/server/ca/models.py b/server/ca/models.py index 1165731fb..5b3d62d30 100644 --- a/server/ca/models.py +++ b/server/ca/models.py @@ -23,7 +23,7 @@ class CA(models.Model): modified_at = models.DateTimeField(auto_now=True) def __str__(self): - return self.config + return str(self.config) class Meta: ordering = ["created_at"] From ad93f1fe6617fbd02b56d75034fee03ef880b9ac Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 00:55:14 +0200 Subject: [PATCH 054/242] remove comment --- cli/medperf/commands/dataset/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index bb67c09af..f22bb86c4 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -73,7 +73,7 @@ def prepare_plan(self): def prepare_pki_assets(self): ca = CA.from_experiment(self.training_exp_id) - # trust(ca) + trust(ca) self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) self.ca = ca From 9f52cd3b9bde24447b2ede030f9f596c6c46d517 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 02:10:46 +0200 Subject: [PATCH 055/242] add --overwrite flag for an easier life --- cli/medperf/commands/aggregator/aggregator.py | 5 +++- 
cli/medperf/commands/aggregator/run.py | 29 +++++++++++++++---- .../commands/certificate/certificate.py | 10 +++++-- .../certificate/client_certificate.py | 14 +++++---- .../certificate/server_certificate.py | 14 +++++---- cli/medperf/commands/dataset/dataset.py | 5 +++- cli/medperf/commands/dataset/train.py | 26 +++++++++++++---- 7 files changed, 78 insertions(+), 25 deletions(-) diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py index f640d8f38..831a71356 100644 --- a/cli/medperf/commands/aggregator/aggregator.py +++ b/cli/medperf/commands/aggregator/aggregator.py @@ -58,9 +58,12 @@ def run( "-t", help="UID of training experiment whose aggregator to be run", ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite outputs if present" + ), ): """Starts the aggregation server of a training experiment""" - StartAggregator.run(training_exp_id) + StartAggregator.run(training_exp_id, overwrite) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index 03eedd616..319599e58 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -1,25 +1,27 @@ +import os from medperf import config from medperf.entities.ca import CA from medperf.entities.event import TrainingEvent -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import InvalidArgumentError, MedperfException from medperf.entities.training_exp import TrainingExp from medperf.entities.aggregator import Aggregator from medperf.entities.cube import Cube -from medperf.utils import get_pki_assets_path +from medperf.utils import get_pki_assets_path, remove_path from medperf.certificates import trust class StartAggregator: @classmethod - def run(cls, training_exp_id: int): + def run(cls, training_exp_id: int, overwrite: bool = False): """Starts the aggregation server of a training experiment Args: training_exp_id 
(int): Training experiment UID. """ - execution = cls(training_exp_id) + execution = cls(training_exp_id, overwrite) execution.prepare() execution.validate() + execution.check_existing_outputs() execution.prepare_aggregator() execution.prepare_participants_list() execution.prepare_plan() @@ -27,8 +29,9 @@ def run(cls, training_exp_id: int): with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id) -> None: + def __init__(self, training_exp_id, overwrite) -> None: self.training_exp_id = training_exp_id + self.overwrite = overwrite self.ui = config.ui def prepare(self): @@ -41,6 +44,22 @@ def validate(self): msg = "The provided training experiment has to start a training event." raise InvalidArgumentError(msg) + def check_existing_outputs(self): + msg = ( + "Outputs still exist from previous runs. Overwrite" + " them by rerunning the command with --overwrite" + ) + paths = [ + self.event.agg_out_logs, + self.event.out_weights, + self.event.report_path, + ] + for path in paths: + if os.path.exists(path): + if not self.overwrite: + raise MedperfException(msg) + remove_path(path) + def prepare_aggregator(self): self.aggregator = Aggregator.from_experiment(self.training_exp_id) self.cube = self.__get_cube(self.aggregator.aggregation_mlcube, "aggregation") diff --git a/cli/medperf/commands/certificate/certificate.py b/cli/medperf/commands/certificate/certificate.py index 3a5e387fc..7125fbcc8 100644 --- a/cli/medperf/commands/certificate/certificate.py +++ b/cli/medperf/commands/certificate/certificate.py @@ -17,9 +17,12 @@ def get_client_certificate( "-t", help="UID of training exp which you intend to be a part of", ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite cert and key if present" + ), ): """get a client certificate""" - GetUserCertificate.run(training_exp_id) + GetUserCertificate.run(training_exp_id, overwrite) config.ui.print("✅ Done!") @@ -32,7 +35,10 @@ def get_server_certificate( "-t", help="UID 
of training exp which you intend to be a part of", ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite cert and key if present" + ), ): """get a server certificate""" - GetServerCertificate.run(training_exp_id) + GetServerCertificate.run(training_exp_id, overwrite) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/certificate/client_certificate.py b/cli/medperf/commands/certificate/client_certificate.py index 2103aca6c..8eace42c8 100644 --- a/cli/medperf/commands/certificate/client_certificate.py +++ b/cli/medperf/commands/certificate/client_certificate.py @@ -1,18 +1,22 @@ from medperf.entities.ca import CA from medperf.account_management import get_medperf_user_data from medperf.certificates import get_client_cert -from medperf.utils import get_pki_assets_path +from medperf.exceptions import MedperfException +from medperf.utils import get_pki_assets_path, remove_path import os class GetUserCertificate: @staticmethod - def run(training_exp_id: int): + def run(training_exp_id: int, overwrite: bool = False): """get user cert""" ca = CA.from_experiment(training_exp_id) email = get_medperf_user_data()["email"] output_path = get_pki_assets_path(email, ca.name) - if os.path.exists(output_path) and os.listdir(output_path): - # TODO? - raise ValueError("already") + if os.path.exists(output_path): + if not overwrite: + raise MedperfException( + "Cert and key already present. 
Rerun the command with --overwrite" + ) + remove_path(output_path) get_client_cert(ca, email, output_path) diff --git a/cli/medperf/commands/certificate/server_certificate.py b/cli/medperf/commands/certificate/server_certificate.py index 68e45b2d6..1e7a25db8 100644 --- a/cli/medperf/commands/certificate/server_certificate.py +++ b/cli/medperf/commands/certificate/server_certificate.py @@ -1,19 +1,23 @@ from medperf.entities.ca import CA from medperf.entities.aggregator import Aggregator from medperf.certificates import get_server_cert -from medperf.utils import get_pki_assets_path +from medperf.exceptions import MedperfException +from medperf.utils import get_pki_assets_path, remove_path import os class GetServerCertificate: @staticmethod - def run(training_exp_id: int): + def run(training_exp_id: int, overwrite: bool = False): """get server cert""" ca = CA.from_experiment(training_exp_id) aggregator = Aggregator.from_experiment(training_exp_id) address = aggregator.address output_path = get_pki_assets_path(address, ca.name) - if os.path.exists(output_path) and os.listdir(output_path): - # TODO? - raise ValueError("already") + if os.path.exists(output_path): + if not overwrite: + raise MedperfException( + "Cert and key already present. 
Rerun the command with --overwrite" + ) + remove_path(output_path) get_server_cert(ca, address, output_path) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index ed1e710f7..77f4e0431 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -154,9 +154,12 @@ def train( data_uid: int = typer.Option( ..., "--data_uid", "-d", help="Registered Dataset UID" ), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite outputs if present" + ), ): """Runs training""" - TrainingExecution.run(training_exp_id, data_uid) + TrainingExecution.run(training_exp_id, data_uid, overwrite) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index f22bb86c4..747c8ddd8 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -3,23 +3,23 @@ from medperf.account_management.account_management import get_medperf_user_data from medperf.entities.ca import CA from medperf.entities.event import TrainingEvent -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import InvalidArgumentError, MedperfException from medperf.entities.training_exp import TrainingExp from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube -from medperf.utils import get_pki_assets_path, get_participant_label +from medperf.utils import get_pki_assets_path, get_participant_label, remove_path from medperf.certificates import trust class TrainingExecution: @classmethod - def run(cls, training_exp_id: int, data_uid: int): + def run(cls, training_exp_id: int, data_uid: int, overwrite: bool = False): """Starts the aggregation server of a training experiment Args: training_exp_id (int): Training experiment UID. 
""" - execution = cls(training_exp_id, data_uid) + execution = cls(training_exp_id, data_uid, overwrite) execution.prepare() execution.validate() execution.prepare_training_cube() @@ -28,9 +28,10 @@ def run(cls, training_exp_id: int, data_uid: int): with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id: int, data_uid: int) -> None: + def __init__(self, training_exp_id: int, data_uid: int, overwrite: bool) -> None: self.training_exp_id = training_exp_id self.data_uid = data_uid + self.overwrite = overwrite self.ui = config.ui def prepare(self): @@ -39,6 +40,7 @@ def prepare(self): self.event = TrainingEvent.from_experiment(self.training_exp_id) self.dataset = Dataset.get(self.data_uid) self.user_email: str = get_medperf_user_data()["email"] + self.out_logs = os.path.join(self.event.col_out_logs, str(self.dataset.id)) def validate(self): if self.dataset.id is None: @@ -58,6 +60,18 @@ def validate(self): # msg = "The provided dataset is not associated." # raise InvalidArgumentError(msg) + def check_existing_outputs(self): + msg = ( + "Outputs still exist from previous runs. 
Overwrite" + " them by rerunning the command with --overwrite" + ) + paths = [self.out_logs] + for path in paths: + if os.path.exists(path): + if not self.overwrite: + raise MedperfException(msg) + remove_path(path) + def prepare_training_cube(self): self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") @@ -86,7 +100,7 @@ def run_experiment(self): "node_cert_folder": self.dataset_pki_assets, "ca_cert_folder": self.ca.pki_assets, "plan_path": self.training_exp.plan_path, - "output_logs": os.path.join(self.event.col_out_logs, str(self.dataset.id)), + "output_logs": self.out_logs, } self.ui.text = "Running Training" From 3c7e9ea134c14b6bf5cc2f7f6959ddc2ccd23268 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 03:08:17 +0200 Subject: [PATCH 056/242] use abs path for training config --- cli/medperf/commands/training/set_plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/commands/training/set_plan.py b/cli/medperf/commands/training/set_plan.py index 1c3e47365..0a959eb02 100644 --- a/cli/medperf/commands/training/set_plan.py +++ b/cli/medperf/commands/training/set_plan.py @@ -30,7 +30,7 @@ def run( def __init__(self, training_exp_id: int, training_config_path: str, approval: bool): self.ui = config.ui self.training_exp_id = training_exp_id - self.training_config_path = training_config_path + self.training_config_path = os.path.abspath(training_config_path) self.approved = approval self.plan_out_path = generate_tmp_path() From ce203252daaa96a78386d10431e9c711e9eb1a13 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 03:10:01 +0200 Subject: [PATCH 057/242] add missing call --- cli/medperf/commands/dataset/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 747c8ddd8..9c1fc9879 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -22,6 +22,7 @@ def run(cls, training_exp_id: 
int, data_uid: int, overwrite: bool = False): execution = cls(training_exp_id, data_uid, overwrite) execution.prepare() execution.validate() + execution.check_existing_outputs() execution.prepare_training_cube() execution.prepare_plan() execution.prepare_pki_assets() From 5b7da4141c8526ccb3f2964af2c7d0a52f5369fa Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 13:33:32 +0200 Subject: [PATCH 058/242] REVERT ME --- cli/medperf/commands/training/start_event.py | 5 ++++- cli/medperf/entities/event.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py index e24c2fa48..adba67ed6 100644 --- a/cli/medperf/commands/training/start_event.py +++ b/cli/medperf/commands/training/start_event.py @@ -36,9 +36,12 @@ def create_participants_list(self): participant_common_name = user_email participants_list[participant_label] = participant_common_name self.participants_list = participants_list + self.participants_list = "\n".join( + [f"{p},{participants_list[p]}" for p in participants_list] + ) def submit(self): - dict_pretty_print(self.participants_list) + print(self.participants_list) msg = ( f"You are about to start an event for the training experiment {self.training_exp.name}." 
" This is the list of participants (participant label, participant common name)" diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py index 0b971c4e1..52e7beea3 100644 --- a/cli/medperf/entities/event.py +++ b/cli/medperf/entities/event.py @@ -24,7 +24,7 @@ class TrainingEvent(Entity, MedperfSchema): """ training_exp: int - participants: dict + participants: str finished: bool = False finished_at: Optional[datetime] report: Optional[dict] @@ -87,7 +87,7 @@ def _Entity__remote_prefilter(cls, filters: dict) -> callable: def prepare_participants_list(self): with open(self.participants_list_path, "w") as f: - yaml.dump(self.participants, f) + f.write(self.participants) def display_dict(self): return { From b21ef08d774ba1c3f908609845dc3543bb1af599 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 14:46:03 +0200 Subject: [PATCH 059/242] remove trailing equals --- cli/medperf/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 103490ac0..d87870020 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -488,6 +488,7 @@ def get_pki_assets_path(common_name: str, ca_name: str): # Base64 encoding is used just to avoid special characters used in emails # and server domains/ipaddresses. 
cn_encoded = base64.b64encode(common_name.encode("utf-8")).decode("utf-8") + cn_encoded = cn_encoded.rstrip("=") return os.path.join(config.pki_assets, cn_encoded, ca_name) From fd8054163eb3ce71923b888f69f740da8feca77e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 18:18:22 +0200 Subject: [PATCH 060/242] fix mock certs issue --- examples/fl/fl/csr.conf | 20 ++++++++++++++------ examples/fl/fl/setup_test_no_docker.sh | 8 ++++---- examples/fl/mock_cert/project/csr.conf | 20 ++++++++++++++------ examples/fl/mock_cert/project/sign.sh | 10 +++++----- server/.env.local.local-auth | 2 +- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/examples/fl/fl/csr.conf b/examples/fl/fl/csr.conf index 3285aed9f..c3b2d0f0c 100644 --- a/examples/fl/fl/csr.conf +++ b/examples/fl/fl/csr.conf @@ -3,21 +3,29 @@ default_bits = 3072 prompt = no default_md = sha384 distinguished_name = req_distinguished_name -req_extensions = req_ext [ req_distinguished_name ] commonName = hasan-hp-zbook-15-g3.home -[ req_ext ] -basicConstraints = critical,CA:FALSE -keyUsage = critical,digitalSignature,keyEncipherment -subjectAltName = @alt_names - [ alt_names ] DNS.1 = hasan-hp-zbook-15-g3.home [ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,clientAuth [ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index 915fe85a8..879e84ced 100644 --- a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -43,7 +43,7 @@ cd mlcube_col1/workspace/node_cert 
openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ @@ -55,7 +55,7 @@ cd mlcube_col2/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ @@ -71,7 +71,7 @@ else openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ @@ -84,7 +84,7 @@ cd mlcube_agg/workspace/node_cert openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf rm csr.csr cp ../../../ca/root.crt ../ca_cert/ cd ../../../ diff --git 
a/examples/fl/mock_cert/project/csr.conf b/examples/fl/mock_cert/project/csr.conf index 3285aed9f..c3b2d0f0c 100644 --- a/examples/fl/mock_cert/project/csr.conf +++ b/examples/fl/mock_cert/project/csr.conf @@ -3,21 +3,29 @@ default_bits = 3072 prompt = no default_md = sha384 distinguished_name = req_distinguished_name -req_extensions = req_ext [ req_distinguished_name ] commonName = hasan-hp-zbook-15-g3.home -[ req_ext ] -basicConstraints = critical,CA:FALSE -keyUsage = critical,digitalSignature,keyEncipherment -subjectAltName = @alt_names - [ alt_names ] DNS.1 = hasan-hp-zbook-15-g3.home [ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,clientAuth [ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl/mock_cert/project/sign.sh b/examples/fl/mock_cert/project/sign.sh index 2b2fd3727..4295351df 100644 --- a/examples/fl/mock_cert/project/sign.sh +++ b/examples/fl/mock_cert/project/sign.sh @@ -20,17 +20,17 @@ fi mkdir -p $OUT cp /mlcube_project/csr.conf $OUT/ cp -r /mlcube_project/ca $OUT/ -CSR_TEMPLATE=$OUT/csr.conf +CSR_CONF=$OUT/csr.conf CA_KEY=$OUT/ca/root.key CA_CERT=$OUT/ca/cert/root.crt -sed -i "/^commonName = /c\commonName = $MEDPERF_INPUT_CN" $CSR_TEMPLATE -sed -i "/^DNS\.1 = /c\DNS.1 = $MEDPERF_INPUT_CN" $CSR_TEMPLATE +sed -i "/^commonName = /c\commonName = $MEDPERF_INPUT_CN" $CSR_CONF +sed -i "/^DNS\.1 = /c\DNS.1 = $MEDPERF_INPUT_CN" $CSR_CONF openssl genpkey -algorithm RSA -out $OUT/key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key $OUT/key.key -out $OUT/csr.csr -config $CSR_TEMPLATE -extensions $EXT +openssl req -new 
-key $OUT/key.key -out $OUT/csr.csr -config $CSR_CONF -extensions $EXT openssl x509 -req -in $OUT/csr.csr -CA $CA_CERT -CAkey $CA_KEY \ - -CAcreateserial -out $OUT/crt.crt -days 36500 -sha384 + -CAcreateserial -out $OUT/crt.crt -days 36500 -sha384 -extensions ${EXT}_crt -extfile $CSR_CONF rm $OUT/csr.csr rm -r $OUT/ca rm -r $OUT/csr.conf diff --git a/server/.env.local.local-auth b/server/.env.local.local-auth index 681d93918..db4449d2d 100644 --- a/server/.env.local.local-auth +++ b/server/.env.local.local-auth @@ -28,4 +28,4 @@ CA_CONFIG={"address":"https://127.0.0.1","port":443,"fingerprint":"fingerprint", CA_MLCUBE_NAME="MedPerf CA" CA_MLCUBE_URL="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/mock_cert/mlcube/mlcube.yaml" CA_MLCUBE_HASH="d3d723fa6e14ea5f3ff1b215c4543295271bebf301d113c4953c5d54310b7dd1" -CA_MLCUBE_IMAGE_HASH="24606b9b48e0da09b01d76870f02a1e24b5128798a6c4b65aae17ab85cb16208" \ No newline at end of file +CA_MLCUBE_IMAGE_HASH="48a16a6b1b42aed79741abf5a799b309feac7f2b4ccb7a8ac89a0fccfc6dd691" \ No newline at end of file From 192767c570a0c74575f0654c735a5b3b1fe87935 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 1 May 2024 18:25:52 +0200 Subject: [PATCH 061/242] add download run files --- cli/medperf/certificates.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cli/medperf/certificates.py b/cli/medperf/certificates.py index dfb9afc9c..5c595a378 100644 --- a/cli/medperf/certificates.py +++ b/cli/medperf/certificates.py @@ -13,6 +13,7 @@ def get_client_cert(ca: CA, email: str, output_path: str): env = {"MEDPERF_INPUT_CN": common_name} mlcube = Cube.get(ca.client_mlcube) + mlcube.download_run_files() mlube_task = "get_client_cert" mlcube.run(task=mlube_task, env_dict=env, **params) @@ -28,6 +29,7 @@ def get_server_cert(ca: CA, address: str, output_path: str): env = {"MEDPERF_INPUT_CN": common_name} mlcube = Cube.get(ca.server_mlcube) + mlcube.download_run_files() mlube_task = 
"get_server_cert" mlcube.run(task=mlube_task, env_dict=env, port=80, **params) @@ -43,5 +45,6 @@ def trust(ca: CA): "pki_assets": ca.pki_assets, } mlcube = Cube.get(ca.ca_mlcube) + mlcube.download_run_files() mlube_task = "trust" mlcube.run(task=mlube_task, **params) From a3e3a543828be373aec748f738a5af139c5a347d Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 2 May 2024 01:42:02 +0200 Subject: [PATCH 062/242] Revert "REVERT ME" This reverts commit 5b7da4141c8526ccb3f2964af2c7d0a52f5369fa. --- cli/medperf/commands/training/start_event.py | 5 +---- cli/medperf/entities/event.py | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py index adba67ed6..e24c2fa48 100644 --- a/cli/medperf/commands/training/start_event.py +++ b/cli/medperf/commands/training/start_event.py @@ -36,12 +36,9 @@ def create_participants_list(self): participant_common_name = user_email participants_list[participant_label] = participant_common_name self.participants_list = participants_list - self.participants_list = "\n".join( - [f"{p},{participants_list[p]}" for p in participants_list] - ) def submit(self): - print(self.participants_list) + dict_pretty_print(self.participants_list) msg = ( f"You are about to start an event for the training experiment {self.training_exp.name}." 
" This is the list of participants (participant label, participant common name)" diff --git a/cli/medperf/entities/event.py b/cli/medperf/entities/event.py index 52e7beea3..0b971c4e1 100644 --- a/cli/medperf/entities/event.py +++ b/cli/medperf/entities/event.py @@ -24,7 +24,7 @@ class TrainingEvent(Entity, MedperfSchema): """ training_exp: int - participants: str + participants: dict finished: bool = False finished_at: Optional[datetime] report: Optional[dict] @@ -87,7 +87,7 @@ def _Entity__remote_prefilter(cls, filters: dict) -> callable: def prepare_participants_list(self): with open(self.participants_list_path, "w") as f: - f.write(self.participants) + yaml.dump(self.participants, f) def display_dict(self): return { From d0ac9ec7076fa7ed7516c9a53790e2466c0c4b85 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 2 May 2024 01:42:57 +0200 Subject: [PATCH 063/242] use ip address in integration tests --- cli/cli_tests_training.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 4d1c703ad..332a0dfa1 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -143,8 +143,8 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -# HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out -HOSTNAME_=$(hostname -A | cut -d " " -f 1) +HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out +# HOSTNAME_=$(hostname -A | cut -d " " -f 1) medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) From cab620d8eeab15297db3803c916d0a09cda3ec3e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 6 May 2024 03:39:53 +0200 Subject: [PATCH 064/242] add option to choose interface for publishing port --- 
cli/cli_tests_training.sh | 9 +++++--- cli/medperf/commands/aggregator/aggregator.py | 8 ++++++- cli/medperf/commands/aggregator/run.py | 22 +++++++++++++++---- cli/medperf/entities/cube.py | 10 +++++++-- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 332a0dfa1..7684fd48c 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -143,8 +143,8 @@ echo "\n" echo "=====================================" echo "Running aggregator submission step" echo "=====================================" -HOSTNAME_=$(hostname -I | cut -d " " -f 1) # todo: figure this out -# HOSTNAME_=$(hostname -A | cut -d " " -f 1) +HOSTNAME_=$(hostname -I | cut -d " " -f 1) +# HOSTNAME_=$(hostname -A | cut -d " " -f 1) # fqdn on github CI runner doesn't resolve from inside containers medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) @@ -406,11 +406,14 @@ checkFailed "testagg profile activation failed" echo "\n" +TRAINING_UID=1 +DSET_1_UID=1 +DSET_2_UID=2 ########################################################## echo "=====================================" echo "Starting aggregator" echo "=====================================" -medperf aggregator start -t $TRAINING_UID agg.log 2>&1 & +medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ agg.log 2>&1 & AGG_PID=$! 
# sleep so that the mlcube is run before we change profiles diff --git a/cli/medperf/commands/aggregator/aggregator.py b/cli/medperf/commands/aggregator/aggregator.py index 831a71356..54775f627 100644 --- a/cli/medperf/commands/aggregator/aggregator.py +++ b/cli/medperf/commands/aggregator/aggregator.py @@ -58,12 +58,18 @@ def run( "-t", help="UID of training experiment whose aggregator to be run", ), + publish_on: str = typer.Option( + "127.0.0.1", + "--publish_on", + "-p", + help="Host network interface on which the aggregator will listen", + ), overwrite: bool = typer.Option( False, "--overwrite", help="Overwrite outputs if present" ), ): """Starts the aggregation server of a training experiment""" - StartAggregator.run(training_exp_id, overwrite) + StartAggregator.run(training_exp_id, publish_on, overwrite) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index 319599e58..6207b6095 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -12,13 +12,13 @@ class StartAggregator: @classmethod - def run(cls, training_exp_id: int, overwrite: bool = False): + def run(cls, training_exp_id: int, publish_on: str, overwrite: bool = False): """Starts the aggregation server of a training experiment Args: training_exp_id (int): Training experiment UID. 
""" - execution = cls(training_exp_id, overwrite) + execution = cls(training_exp_id, publish_on, overwrite) execution.prepare() execution.validate() execution.check_existing_outputs() @@ -29,9 +29,10 @@ def run(cls, training_exp_id: int, overwrite: bool = False): with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id, overwrite) -> None: + def __init__(self, training_exp_id, publish_on, overwrite) -> None: self.training_exp_id = training_exp_id self.overwrite = overwrite + self.publish_on = publish_on self.ui = config.ui def prepare(self): @@ -43,6 +44,14 @@ def validate(self): if self.event.finished: msg = "The provided training experiment has to start a training event." raise InvalidArgumentError(msg) + if self.publish_on == "127.0.0.1": + pass + # config.ui.print_warning("This has a bug...TODO") + # TODO: take confirmation somewhere about the whole process + # TODO: We should start checking inputs before proceeding. For example, + # now if the user provided some malformed network interface, this + # will not throw an error until many calls to the server has been made + # and things are configured... 
def check_existing_outputs(self): msg = ( @@ -96,4 +105,9 @@ def run_experiment(self): } self.ui.text = "Running Aggregator" - self.cube.run(task="start_aggregator", port=self.aggregator.port, **params) + self.cube.run( + task="start_aggregator", + port=self.aggregator.port, + publish_on=self.publish_on, + **params, + ) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 4451fe1aa..b303102a2 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -232,6 +232,7 @@ def run( timeout: int = None, read_protected_input: bool = True, port=None, + publish_on=None, env_dict: dict = {}, **kwargs, ): @@ -283,8 +284,12 @@ def run( cpu_args = " ".join([cpu_args, "-u $(id -u):$(id -g)"]).strip() gpu_args = " ".join([gpu_args, "-u $(id -u):$(id -g)"]).strip() if port is not None: - cpu_args += f" -p {port}:{port}" - gpu_args += f" -p {port}:{port}" + if publish_on: + cpu_args += f" -p {publish_on}:{port}:{port}" + gpu_args += f" -p {publish_on}:{port}:{port}" + else: + cpu_args += f" -p {port}:{port}" + gpu_args += f" -p {port}:{port}" cmd += f' -Pdocker.cpu_args="{cpu_args}"' cmd += f' -Pdocker.gpu_args="{gpu_args}"' if env_args_string: # TODO: why MLCube UI is so brittle? @@ -299,6 +304,7 @@ def run( run_args += " " + env_args_string cmd += f' -Psingularity.run_args="{run_args}"' # TODO: check if ports are already exposed. 
Think if this is OK + # TODO: check about exposing to specific network interfaces # TODO: check if --env works # set image name in case of running docker image with singularity From 0b5e4c0a0b4aee38481e33ff41ac80d019f231bf Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 17 May 2024 09:07:58 +0200 Subject: [PATCH 065/242] remove some TODOs from the code --- cli/medperf/commands/aggregator/run.py | 6 ------ cli/medperf/commands/dataset/train.py | 5 ----- cli/medperf/commands/training/training.py | 2 +- cli/medperf/utils.py | 2 +- examples/fl/fl/project/utils.py | 7 ------- server/training/models.py | 1 - 6 files changed, 2 insertions(+), 21 deletions(-) diff --git a/cli/medperf/commands/aggregator/run.py b/cli/medperf/commands/aggregator/run.py index 6207b6095..0ba96f3a0 100644 --- a/cli/medperf/commands/aggregator/run.py +++ b/cli/medperf/commands/aggregator/run.py @@ -46,12 +46,6 @@ def validate(self): raise InvalidArgumentError(msg) if self.publish_on == "127.0.0.1": pass - # config.ui.print_warning("This has a bug...TODO") - # TODO: take confirmation somewhere about the whole process - # TODO: We should start checking inputs before proceeding. For example, - # now if the user provided some malformed network interface, this - # will not throw an error until many calls to the server has been made - # and things are configured... def check_existing_outputs(self): msg = ( diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 9c1fc9879..7fda13447 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -56,11 +56,6 @@ def validate(self): msg = "The provided training experiment has to start a training event." raise InvalidArgumentError(msg) - # TODO: Do we need this? This basically would make participants list public to them - # if self.dataset.id not in TrainingExp.get_datasets_uids(self.training_exp_id): - # msg = "The provided dataset is not associated." 
- # raise InvalidArgumentError(msg) - def check_existing_outputs(self): msg = ( "Outputs still exist from previous runs. Overwrite" diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index f42adbf19..36328eb96 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -39,7 +39,7 @@ def submit( "description": description, "docs_url": docs_url, "fl_mlcube": fl_mlcube, - "demo_dataset_tarball_url": "link", # TODO later + "demo_dataset_tarball_url": "link", "demo_dataset_tarball_hash": "hash", "demo_dataset_generated_uid": "uid", "data_preparation_mlcube": prep_mlcube, diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index d87870020..00d06abd8 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -493,5 +493,5 @@ def get_pki_assets_path(common_name: str, ca_name: str): def get_participant_label(email, data_id): - # return f"d{data_id}" # TODO: use this when building openfl fork + # return f"d{data_id}" return f"{email}" diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index 4cee2ba39..d92add606 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -22,8 +22,6 @@ def get_aggregator_fqdn(fl_workspace): def get_collaborator_cn(): - # TODO: check if there is a way this can cause a collision/race condition - # TODO: from inside the file return os.environ["MEDPERF_PARTICIPANT_LABEL"] @@ -39,7 +37,6 @@ def get_weights_path(fl_workspace): def prepare_plan(plan_path, fl_workspace): target_plan_folder = os.path.join(fl_workspace, "plan") - # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "plan.yaml") @@ -59,7 +56,6 @@ def prepare_cols_list(collaborators_file, fl_workspace): cols_dict = list(cols_dict.keys()) target_plan_folder = os.path.join(fl_workspace, "plan") - # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) 
target_plan_file = os.path.join(target_plan_folder, "cols.yaml") with open(target_plan_file, "w") as f: @@ -79,7 +75,6 @@ def prepare_init_weights(input_weights, fl_workspace): target_weights_subpath = get_weights_path(fl_workspace)["init"] target_weights_path = os.path.join(fl_workspace, target_weights_subpath) target_weights_folder = os.path.dirname(target_weights_path) - # TODO: permissions os.makedirs(target_weights_folder, exist_ok=True) os.symlink(file, target_weights_path) @@ -105,7 +100,6 @@ def prepare_node_cert( cert_file = os.path.join(node_cert_folder, cert_file) target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) - # TODO: permissions os.makedirs(target_cert_folder, exist_ok=True) target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") @@ -125,7 +119,6 @@ def prepare_ca_cert(ca_cert_folder, fl_workspace): file = os.path.join(ca_cert_folder, file) target_ca_cert_folder = os.path.join(fl_workspace, "cert") - # TODO: permissions os.makedirs(target_ca_cert_folder, exist_ok=True) target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") diff --git a/server/training/models.py b/server/training/models.py index 5119a5e8c..a65653119 100644 --- a/server/training/models.py +++ b/server/training/models.py @@ -34,7 +34,6 @@ class TrainingExperiment(models.Model): ) metadata = models.JSONField(default=dict, blank=True, null=True) - # TODO: consider if we want to enable restarts and epochs/"fresh restarts" state = models.CharField(choices=STATES, max_length=100, default="DEVELOPMENT") is_valid = models.BooleanField(default=True) approval_status = models.CharField( From 4f730351cc24bf7a888237ed1a7adce9ab8b3008 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 17 May 2024 09:09:16 +0200 Subject: [PATCH 066/242] add ca server and client code --- .gitignore | 2 + examples/fl/cert/build.sh | 1 + 
examples/fl/cert/mlcube/mlcube.yaml | 6 +- .../fl/cert/mlcube/workspace/ca_config.json | 4 +- examples/fl/cert/project/Dockerfile | 12 ++- examples/fl/cert/project/get_cert.sh | 35 ++++++- examples/fl/cert/project/trust.sh | 18 +++- examples/fl/cert/test.sh | 6 ++ flca/Dockerfile.dev | 22 +++++ flca/Dockerfile.prod | 22 +++++ flca/README.md | 80 ++++++++++++++++ flca/cloudbuild.yaml | 45 +++++++++ flca/dev_assets/ca.json | 57 +++++++++++ flca/dev_assets/client.tpl | 10 ++ flca/dev_assets/db_config.json | 4 + flca/dev_assets/intermediate_ca.crt | 27 ++++++ flca/dev_assets/intermediate_ca_key | 42 ++++++++ flca/dev_assets/pwd.txt | 1 + flca/dev_assets/reverse_proxy.conf | 10 ++ flca/dev_assets/root_ca.crt | 26 +++++ flca/dev_assets/server.tpl | 10 ++ flca/dev_assets/settings.json | 11 +++ flca/dev_utils.py | 23 +++++ flca/entrypoint.sh | 13 +++ flca/manual_setup/README.md | 1 + flca/manual_setup/create_keys.sh | 17 ++++ flca/manual_setup/rsa_intermediate_ca.tpl | 12 +++ flca/manual_setup/rsa_root_ca.tpl | 12 +++ flca/setup.py | 96 +++++++++++++++++++ flca/utils.py | 49 ++++++++++ 30 files changed, 657 insertions(+), 17 deletions(-) create mode 100644 examples/fl/cert/build.sh create mode 100644 examples/fl/cert/test.sh create mode 100644 flca/Dockerfile.dev create mode 100644 flca/Dockerfile.prod create mode 100644 flca/README.md create mode 100644 flca/cloudbuild.yaml create mode 100644 flca/dev_assets/ca.json create mode 100644 flca/dev_assets/client.tpl create mode 100644 flca/dev_assets/db_config.json create mode 100644 flca/dev_assets/intermediate_ca.crt create mode 100644 flca/dev_assets/intermediate_ca_key create mode 100644 flca/dev_assets/pwd.txt create mode 100644 flca/dev_assets/reverse_proxy.conf create mode 100644 flca/dev_assets/root_ca.crt create mode 100644 flca/dev_assets/server.tpl create mode 100644 flca/dev_assets/settings.json create mode 100644 flca/dev_utils.py create mode 100644 flca/entrypoint.sh create mode 100644 flca/manual_setup/README.md 
create mode 100644 flca/manual_setup/create_keys.sh create mode 100644 flca/manual_setup/rsa_intermediate_ca.tpl create mode 100644 flca/manual_setup/rsa_root_ca.tpl create mode 100644 flca/setup.py create mode 100644 flca/utils.py diff --git a/.gitignore b/.gitignore index 26d4cdd2e..048016fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,5 @@ server/keys # exclude fl example !examples/fl/mock_cert/project/ca/root.key !examples/fl/mock_cert/project/ca/cert/root.crt +!flca/dev_assets/intermediate_ca.crt +!flca/dev_assets/root_ca.crt diff --git a/examples/fl/cert/build.sh b/examples/fl/cert/build.sh new file mode 100644 index 000000000..d56304274 --- /dev/null +++ b/examples/fl/cert/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/cert/mlcube/mlcube.yaml b/examples/fl/cert/mlcube/mlcube.yaml index 700782b00..5612dd379 100644 --- a/examples/fl/cert/mlcube/mlcube.yaml +++ b/examples/fl/cert/mlcube/mlcube.yaml @@ -16,21 +16,21 @@ docker: tasks: trust: - entrypoint: /bin/bash /mlcube_project/trust.sh + entrypoint: /bin/sh /mlcube_project/trust.sh trust parameters: inputs: ca_config: ca_config.json outputs: pki_assets: pki_assets/ get_client_cert: - entrypoint: /bin/bash /mlcube_project/get_cert.sh + entrypoint: /bin/sh /mlcube_project/get_cert.sh get_client_cert parameters: inputs: ca_config: ca_config.json outputs: pki_assets: pki_assets/ get_server_cert: - entrypoint: /bin/bash /mlcube_project/get_cert.sh + entrypoint: /bin/sh /mlcube_project/get_cert.sh get_server_cert parameters: inputs: ca_config: ca_config.json diff --git a/examples/fl/cert/mlcube/workspace/ca_config.json b/examples/fl/cert/mlcube/workspace/ca_config.json index fbf4b8696..b701d00f8 100644 --- a/examples/fl/cert/mlcube/workspace/ca_config.json +++ b/examples/fl/cert/mlcube/workspace/ca_config.json @@ -1,7 +1,7 @@ { - "address": "https://example.com", + "address": "https://flcerts.medperf.org", "port": 443, - "fingerprint": 
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + "fingerprint": "", "client_provisioner": "auth0", "server_provisioner": "acme" } \ No newline at end of file diff --git a/examples/fl/cert/project/Dockerfile b/examples/fl/cert/project/Dockerfile index 9a36c1bd2..227625bee 100644 --- a/examples/fl/cert/project/Dockerfile +++ b/examples/fl/cert/project/Dockerfile @@ -1,7 +1,13 @@ -FROM smallstep/step-cli:0.26.1 +FROM python:3.11.9-alpine -RUN apt-get update && apt-get install jq-y +# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 +RUN apk update && apk add openssl=3.1.4-r6 jq + +ARG VERSION=0.26.1 +RUN wget https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin COPY . /mlcube_project -ENTRYPOINT ["/bin/bash"] +ENTRYPOINT ["/bin/sh"] \ No newline at end of file diff --git a/examples/fl/cert/project/get_cert.sh b/examples/fl/cert/project/get_cert.sh index 83c8bac40..39aa899a8 100644 --- a/examples/fl/cert/project/get_cert.sh +++ b/examples/fl/cert/project/get_cert.sh @@ -38,6 +38,8 @@ CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) CA_CLIENT_PROVISIONER=$(jq -r '.client_provisioner' $ca_config) CA_SERVER_PROVISIONER=$(jq -r '.server_provisioner' $ca_config) +export STEPPATH=$pki_assets/.step + if [ "$task" = "get_server_cert" ]; then PROVISIONER_ARGS="--provisioner $CA_SERVER_PROVISIONER" elif [ "$task" = "get_client_cert" ]; then @@ -50,13 +52,36 @@ fi cert_path=$pki_assets/crt.crt key_path=$pki_assets/key.key -# trust the CA. -step ca bootstrap --ca-url $CA_ADDRESS:$CA_PORT \ - --fingerprint $CA_FINGERPRINT +if [ -e $STEPPATH ]; then + echo ".step folder already exists" + exit 1 +fi + +if [ -e $cert_path ]; then + echo "cert file already exists" + exit 1 +fi + +if [ -e $key_path ]; then + echo "key file already exists" + exit 1 +fi + +if [ -n "$CA_FINGERPRINT" ]; then + # trust the CA. 
+ step ca bootstrap --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT + ROOT=$STEPPATH/certs/root_ca.crt +else + ROOT=/etc/ssl/certs/ca-certificates.crt +fi # generate private key and ask for a certificate -# $STEPPATH/certs/root_ca.crt is the path where step-ca stores the trusted ca cert by default step ca certificate --ca-url $CA_ADDRESS:$CA_PORT \ - --root $STEPPATH/certs/root_ca.crt \ + --root $ROOT \ + --kty=RSA \ $PROVISIONER_ARGS \ $MEDPERF_INPUT_CN $cert_path $key_path + +# cleanup +rm -rf $STEPPATH diff --git a/examples/fl/cert/project/trust.sh b/examples/fl/cert/project/trust.sh index 7c9ecd4b8..ceb2a303a 100644 --- a/examples/fl/cert/project/trust.sh +++ b/examples/fl/cert/project/trust.sh @@ -32,11 +32,21 @@ if [ "$task" != "trust" ]; then exit 1 fi +export STEPPATH=$pki_assets/.step + CA_ADDRESS=$(jq -r '.address' $ca_config) CA_PORT=$(jq -r '.port' $ca_config) CA_FINGERPRINT=$(jq -r '.fingerprint' $ca_config) -# trust the CA. -rm -rf $pki_assets/* -step ca root $pki_assets/root_ca.crt --ca-url $CA_ADDRESS:$CA_PORT \ - --fingerprint $CA_FINGERPRINT +rm -rf $pki_assets/root_ca.crt + +if [ -n "$CA_FINGERPRINT" ]; then + # trust the CA. 
+ step ca root $pki_assets/root_ca.crt --ca-url $CA_ADDRESS:$CA_PORT \ + --fingerprint $CA_FINGERPRINT +else + wget -O $pki_assets/root_ca.crt $CA_ADDRESS:$CA_PORT/roots.pem +fi + +# cleanup +rm -rf $STEPPATH diff --git a/examples/fl/cert/test.sh b/examples/fl/cert/test.sh new file mode 100644 index 000000000..e4862c8b6 --- /dev/null +++ b/examples/fl/cert/test.sh @@ -0,0 +1,6 @@ +medperf mlcube run --mlcube ./mlcube --task get_client_cert -e MEDPERF_INPUT_CN=hasan.kassem@mlcommons.org +medperf mlcube run --mlcube ./mlcube --task get_server_cert -e MEDPERF_INPUT_CN=34.41.173.238 -P 80 +# medperf mlcube run --mlcube ./mlcube --task get_server_cert +medperf mlcube run --mlcube ./mlcube --task trust +# docker run -it --entrypoint=/bin/bash --env MEDPERF_INPUT_CN=col1@example.com --volume '/home/hasan/work/medperf_ws/medperf/examples/fl/cert/mlcube/workspace:/mlcube_io0:ro' --volume '/home/hasan/work/medperf_ws/medperf/examples/fl/cert/mlcube/workspace/pki_assets:/mlcube_io1' mlcommons/medperf-step-cli:0.0.0 +# bash /mlcube_project/get_cert.sh get_client_cert --ca_config=/mlcube_io0/ca_config.json --pki_assets=/mlcube_io1 diff --git a/flca/Dockerfile.dev b/flca/Dockerfile.dev new file mode 100644 index 000000000..45e3b1340 --- /dev/null +++ b/flca/Dockerfile.dev @@ -0,0 +1,22 @@ +FROM python:3.11.9-alpine + +ENV USE_PROXY=1 + +# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 +RUN apk update && apk add openssl=3.1.4-r6 tar curl && if [[ -n "${USE_PROXY}" ]]; then apk add nginx; fi + +ARG VERSION=0.26.1 +RUN curl -LO https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin +RUN curl -LO https://dl.smallstep.com/gh-release/certificates/gh-release-header/v${VERSION}/step-ca_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step-ca_linux_${VERSION}_amd64.tar.gz \ + && cp step-ca /usr/bin + + +COPY ./dev_utils.py /utils.py 
+ +COPY ./setup.py /setup.py +COPY ./entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/sh", "/entrypoint.sh"] diff --git a/flca/Dockerfile.prod b/flca/Dockerfile.prod new file mode 100644 index 000000000..6af8a8904 --- /dev/null +++ b/flca/Dockerfile.prod @@ -0,0 +1,22 @@ +FROM python:3.11.9-alpine + +ENV USE_PROXY=1 + +# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 +RUN apk update && apk add openssl=3.1.4-r6 tar curl && if [[ -n "${USE_PROXY}" ]]; then apk add nginx; fi + +ARG VERSION=0.26.1 +RUN curl -LO https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step_linux_${VERSION}_amd64.tar.gz \ + && cp step_${VERSION}/bin/step /usr/bin +RUN curl -LO https://dl.smallstep.com/gh-release/certificates/gh-release-header/v${VERSION}/step-ca_linux_${VERSION}_amd64.tar.gz \ + && tar -xf step-ca_linux_${VERSION}_amd64.tar.gz \ + && cp step-ca /usr/bin + +RUN pip install google-cloud-secret-manager==2.20.0 +COPY ./utils.py /utils.py + +COPY ./setup.py /setup.py +COPY ./entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["/bin/sh", "/entrypoint.sh"] \ No newline at end of file diff --git a/flca/README.md b/flca/README.md new file mode 100644 index 000000000..f5916a8f6 --- /dev/null +++ b/flca/README.md @@ -0,0 +1,80 @@ +# Deploying Step + +## For Production + +### Configuration and secrets + +#### `ca_config` + +An example of this file can be found in `dev_assets/ca.json` (or [here](https://smallstep.com/docs/step-ca/configuration/#basic-configuration-options)). This contains ca configuration, and will be modified during runtime as follows: + + * `root`: will point to the path of the root ca cert after it gets downloaded and stored. + * `crt`: will point to the path of the intermediate ca cert after it gets downloaded and stored. + * `key`: will point to the path of the intermediate ca key after it gets downloaded and stored. + * `db`: will contain the database configuration. 
(will be taken from a secret variable) + * `authority.provisioners.0.options.x509.templateFile`: will point to the path of the OIDC provisioner cert template after it gets downloaded and stored. + * `authority.provisioners.1.options.x509.templateFile`: will point to the path of the ACME provisioner cert template after it gets downloaded and stored. + +#### Other configuration files + + * `root_ca_crt`: The root ca certificate. + * `intermediate_ca_crt`: The intermediate ca certificate. + * `client_x509_template`: The OIDC provisioner cert template. + * `server_x509_template`: The ACME provisioner cert template. + * `proxy_config`: If a proxy need to be used, this contains an Nginx server configuration. + +#### Secrets + + * `intermediate_ca_key`: The intermediate ca key. + * `intermediate_ca_password`: The password used to encrypt the intermediate ca key. + * `db_config`: Database connection configuration. + +#### Main settings file + +All secrets and configurations are separately stored in GCP's secret manager. There is a main settings file `settings.json` that is also stored on the secret manager, and it is a JSON file that contains references to the other secrets/configurations. + +### Deployment + + * Build + + ```sh + docker build -t tmptag -f Dockerfile.prod . + ``` + + * tag + + ```sh + TAG=$(docker image ls | grep tmptag | tr -s " " | awk '{$1=$1;print}' | cut -d " " -f 3) + docker tag tmptag us-west1-docker.pkg.dev/medperf-330914/medperf-repo/medperf-ca:$TAG + ``` + + * Push + + ```sh + docker push us-west1-docker.pkg.dev/medperf-330914/medperf-repo/medperf-ca:$TAG + ``` + + * Setup secrets and configurations + * Edit `cloudbuild.yaml` as needed. You may change: + * the service account that will bind to the deployed instance. + * the port + * the service name if planning to deploy a new service, not a new revision of the existing service. + * SQL instance + * ... 
+ * Run `gcloud builds submit --config=cloudbuild.yaml --substitutions=SHORT_SHA=$TAG` + +## For Development + +### Configuration and secrets + +The folder `dev_assets` contains configurations and ""secrets"" described above, but for development. + +### Deployment + +Build using `Dockerfile.dev` (tag it say with `local/devca:0.0.0`), then run: + +```sh +docker run --volume ./dev_assets:/assets -p :443:443 local/devca:0.0.0 +``` + +Set `` as you wish (`0.0.0.0`, `127.0.0.1`, `$(hostname -I | cut -d " " -f 1)`, ...) diff --git a/flca/cloudbuild.yaml b/flca/cloudbuild.yaml new file mode 100644 index 000000000..158cfc889 --- /dev/null +++ b/flca/cloudbuild.yaml @@ -0,0 +1,45 @@ +#The script is invoked manually with all settings provided in the secret +#It assumes that DB is created before the script run +#Inorder to deploy a service, pass sha-id of the already built image +#Command: gcloud builds submit --config=cloudbuild.yaml --substitutions=SHORT_SHA= +steps: + - id: "deploy cloud run" + name: "gcr.io/cloud-builders/gcloud" + args: + [ + "run", + "deploy", + "${_CLOUD_RUN_SERVICE_NAME}", + "--platform", + "managed", + "--region", + "${_REGION}", + "--image", + "${_REGION}-${_ARTIFACT_REGISTRY_DOMAIN}/${PROJECT_ID}/${_REPO_NAME}/${_IMAGE_NAME}:${SHORT_SHA}", + "--add-cloudsql-instances", + "${PROJECT_ID}:${_REGION}:${_SQL_INSTANCE_NAME}", + "--set-env-vars", + "SETTINGS_SECRETS_NAME=${_SECRET_SETTINGS_NAME}", + "--allow-unauthenticated", + "--min-instances", + "${_CLOUD_RUN_MIN_INSTANCES}", + "--port", + "${_PORT}", + "--service-account", + "${_SERVICE_ACCOUNT}" + ] + +substitutions: + _REGION: us-west1 + _ARTIFACT_REGISTRY_DOMAIN: docker.pkg.dev + _REPO_NAME: medperf-repo + _IMAGE_NAME: medperf-ca + _CLOUD_RUN_SERVICE_NAME: medperf-ca + _CLOUD_RUN_MIN_INSTANCES: "1" + _SECRET_SETTINGS_NAME: medperf-ca-settings + _SQL_INSTANCE_NAME: medperf-dev + _PORT: "443" + _SERVICE_ACCOUNT: "medperf-ca@medperf-330914.iam.gserviceaccount.com" + +options: + dynamic_substitutions: 
true diff --git a/flca/dev_assets/ca.json b/flca/dev_assets/ca.json new file mode 100644 index 000000000..8e6e191d3 --- /dev/null +++ b/flca/dev_assets/ca.json @@ -0,0 +1,57 @@ +{ + "root": "/stephome/certs/root_ca.crt", + "federatedRoots": null, + "crt": "/stephome/certs/intermediate_ca.crt", + "key": "/stephome/secrets/intermediate_ca_key", + "dnsNames": [ + "127.0.0.1" + ], + "address": "127.0.0.1:8000", + "logger": { + "format": "text" + }, + "db": "", + "authority": { + "provisioners": [ + { + "type": "ACME", + "name": "acme", + "options": { + "x509": { + "templateFile": "/stephome/templates/certs/x509/server.tpl" + } + } + }, + { + "type": "OIDC", + "name": "auth0", + "clientID": "kQoZ38ESRfUuMUUBlQRv2gWwOwGAMOqd", + "clientSecret": "", + "configurationEndpoint": "https://dev-5xl8y6uuc2hig2ly.us.auth0.com/.well-known/openid-configuration", + "options": { + "x509": { + "templateFile": "/stephome/templates/certs/x509/client.tpl" + }, + "ssh": {} + } + } + ], + "claims": { + "minTLSCertDuration": "8766h", + "maxTLSCertDuration": "8766h", + "defaultTLSCertDuration": "8766h", + "disableRenewal": true + }, + "template": {}, + "backdate": "1m0s" + }, + "tls": { + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" + ], + "minVersion": 1.2, + "maxVersion": 1.3, + "renegotiation": false + } +} \ No newline at end of file diff --git a/flca/dev_assets/client.tpl b/flca/dev_assets/client.tpl new file mode 100644 index 000000000..e4f3523f5 --- /dev/null +++ b/flca/dev_assets/client.tpl @@ -0,0 +1,10 @@ +{ + "subject": {{ toJson .Token.email }}, + "sans": {{ toJson .SANs }}, +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + "keyUsage": ["dataEncipherment", "digitalSignature"], +{{- else }} + {{ fail "Key type must be RSA. 
Try again with --kty=RSA" }} +{{- end }} + "extKeyUsage": ["serverAuth", "clientAuth"] +} \ No newline at end of file diff --git a/flca/dev_assets/db_config.json b/flca/dev_assets/db_config.json new file mode 100644 index 000000000..5b1cabc9c --- /dev/null +++ b/flca/dev_assets/db_config.json @@ -0,0 +1,4 @@ +{ + "type": "badgerv2", + "dataSource": "/db" +} \ No newline at end of file diff --git a/flca/dev_assets/intermediate_ca.crt b/flca/dev_assets/intermediate_ca.crt new file mode 100644 index 000000000..2619b1f26 --- /dev/null +++ b/flca/dev_assets/intermediate_ca.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIElTCCAsmgAwIBAgIRAM0SvcPdc4E0j1asWd+pCoQwQQYJKoZIhvcNAQEKMDSg +DzANBglghkgBZQMEAgEFAKEcMBoGCSqGSIb3DQEBCDANBglghkgBZQMEAgEFAKID +AgEgMBoxGDAWBgNVBAMTD01lZFBlcmYgUm9vdCBDQTAeFw0yNDA1MDkxMjU2NDZa +Fw0zNDA1MTAwMDU2NDFaMCIxIDAeBgNVBAMTF01lZFBlcmYgSW50ZXJtZWRpYXRl +IENBMIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAyz2XEnXNheKF5ul0 +TgvMcvfAEmQh1IjrtNsisk1Jcep82bInK6GiKTvr4f1JzqelqEY1PrZnjcKtG+Q7 +AttYtAZi+mH5FnxDZAEvPSMd4Feo7y0QCvjsfe6jwUX6XToh+2ET89863xCd9JLi +iy9FvjbZWLx3trN1k/aOBlotKUXaLOHRaOO4GTpzZymN+HElsYSxXhxiWVAuQQF3 +GEEXAi1rlekNDt46cqI230M4rI9FWuRZTOaVsHm7OyyOXJt7inbKvWsgWF9+YaZS +lfhuuj6RKDpK32DOekzAR39mjdJf+EsXnIw0Jx3Mqcq2n6xjh9/72Z88CpVFrcEm +TeSnohAEUM6f1a4oTxYCl+FOV8RhplCoC/NAxSGbfmZ0y8WYNeGwSQyAFFpBX2rO +zsXIWGbclRLkzklbfTUf70Fi4hjOzHEPrGK8J5bIjiOk/l8dc3vQRu30OmUEYWRY +g9FleBS67roTijEtWN9V63MLMpouBJugTN/xO48VHHYbolkZAgMBAAGjZjBkMA4G +A1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0GA1UdDgQWBBSyaUMC +TOjmPZn7ZTDXG6Qo8qKi4TAfBgNVHSMEGDAWgBT0pIgmVPfaZgIy7Cx1yxRc7SRb +UzBBBgkqhkiG9w0BAQowNKAPMA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZIhvcNAQEI +MA0GCWCGSAFlAwQCAQUAogMCASADggGBAIXTa3szlsPLh7cv5X/ZgFzSF/6dtRmC +1UlQz8YvQUU5qzI53fvvVCph0+I02YHjA8npospXB0v9FTwiU2TrdYR9Kld/JMzI +zQrkyZC1ePFiZrc7fLzTp8kuo78rUqgJY6TgGxBwH0GfwonJo9/wr0Xxnzex3h+P +3DobGjlcVkEpvgERL0JLvmWdi/Vh3saiaD6a+Lvy6c2dt5YpF6RHK8v0DEpzyfM4 
+xDdkYA1z4i2sjZntshxpm21Q7K9I80jCv+MuqUPKcyVmlGlVfaxU0HynjPIY/53Y +cjIA8QVnW8Wf4iLn/Tg4To2YLXZx5R08iuyU/r3d5QWlqs24UZgJvqRxnkFnQE9n +wjXB/MtSG76VRFHNrcVAajNY4hfE4N4ZQN1IVoXA32a2arpMaAuZiLUmH8usJ1B0 +AvHP+tgKGJbxomDINqcA7kgf5SFMN3ocAhvBwvk3+NMMZqvb8KgiViavIe+A3Of0 +UaA1RyJhmgymzVjnDRSvVT5zddkpe6k38A== +-----END CERTIFICATE----- diff --git a/flca/dev_assets/intermediate_ca_key b/flca/dev_assets/intermediate_ca_key new file mode 100644 index 000000000..e8e1ef8ba --- /dev/null +++ b/flca/dev_assets/intermediate_ca_key @@ -0,0 +1,42 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,1417c765eac64d4dcc4606c745aa811e + +8yCvhJYJTWTSu72X7sbunRTfb2VwzyLtIokzF2jIIGnlXGcN5YtNTqHdoOWha94B +FPlyLuzvGlzSjpEtYIcBYNReY0l75KWIWJ0AP4OufHobyzJ5FttTHQPUh72v6iPt +bWWS68Y5tsfJLIlB0BBrkfIHH7eh8kEXN75MZnXAPc4d8RccN6SYoubCKWqNFvYb +9QloWLGVP58sd5ILSf5zE50KYO3ZMoDgjK3W0hTD5ocBtRYzD3m5toyFdA5Sxufc +Ea7JI2NJ8Nch1jLwW0MbtDC77yu1RfVGuJAfLZsQwyXPVT4QxOZQoRLWMeUuZSe9 ++EfNsjrwmNHu2h+z9BXVYq6DX5lpE0FBnzzIsOXCkmSsyDa0/7q3yAD63dGS9NF1 +hMypj840NBoF1X6PIdG0r49HhsfQ9witQh1H77ZM0HzSBN1jo9dmljgLocU+cAv5 +sPt0H/+BQzwOJG8U1SYfW0Sa7airYfY3mdZgqX/Ghzp2uZavmD6UYQgvHFaSX1j3 +n4EJUZpV489FyWRcd+0e/czzZ+QJcRsGUeYOvLMGKXZSKpG68/lMjZ8hHQ3t+DuN +3OqotHK6B5HEayMcQn6Hglscfq72UbiaqkXIwNBugSnwwLtO+ObHEyC1DSldgXpt +GQgwh9YHHbJioTy4ZQxoxjCkRnRWk8VSgEs1bH8MxVuVUsiJNrkY82AwwQ+Z5bpx +cC0niMIRuqKG2Ja0IzUnqhkAjYprYn55dIa05iCMtkNCT5DuQbRa8GlGqsH/pzN1 +jQuB5TTigcpnj8b97GXSDBpPej7pZ48vGZ8+CyIo72hd8nMYeZjcNp2rzM87Aa5G +15z36341apE/4yRya5nRMGXObdBTHiZfQzCrVZ2pIYqhGLew3UU98a9LJwR28+KO +cH5V67P9hOhHwj1/O6WKhDJBeyCrnEjDjXfY4/CzKzL72x8iy6ujGG71eu9Ton6J +mCEpXZlamIchBnVM1o1VM7C3PaMJnIwD8JsUt/G3siBKG+Xa9qehe2agkMvOKDVI +zpC7IjpzXjPQ9TpvCxQ+A7D6yNL4p3NjdVW6vlYKa58LESuXhxvr5QsfXUYk70B+ +95J/vpgWyoFadvWIYVAvjcrYRPlgUjJDf1tz/kxqkiWSbDnCMdPG7NeX+RTFY4IW +ifzvrSeTo2k4ceYKFMvWB7NmKAgPISvMDggM1irdrEuoYmOFofCnVRdxHOpXB3w0 +2MazMFXyXP2AEOkTU7FPyO5dRVrtbXIAggCfUpIE8CdseO6YO9k8LAKfehVHWIqX 
+H+SVUYhH/Ij/B5u7/YLVDtYfup7vzF637nOtI3LODo/XWFBZmFt/mpeTnF1v5UTU +irDi2cma4fcKs8FDIAgq8SYX/jyz9+7tD1yVhP8qXJqESmi04kHBs7D2O+MbiPg5 +gkHP0B8ZVwClieG/ovAJloZFUVcr+PgN+eeUK1VhBoN9GQotJ6oxO0YBaExbnGne +14TJArS9xBhBhmNh79J9QTaoSDro995uS4UWjEtm6vnwSU7X2U44pcY4O+cxlJjK +rHIjmAOa988F7WSGPg3PniFZ7LTU5qpcJ0blebrciYGhgk8YN7x58oncjkK7vIVV +OO4WJbLfkAbPX4UePiIp/3oMNZD+4ndChPcd0T8xlHaKkrB8jtUQk6Cfgeg5IWlZ +AD1FNyvrLDdigef/rKJ1VOjQjG4bPcdkSB7zo1MlWzHGJOQl0+U06m0whdAhhyIb +cC1MbeszAe+v8+9F2VaNEiRzQzEQAcTcdLOQoMB0q6WVE/G7QN88GMVV4KfqrK64 +jyH9tK2u8itmoxE5tSVwCU5Eys12li15/l1bg/LormlfO2r/nZX7L0ifjWWcaRAn +YFEhzVErS1/7671lINKN+tlJS7L8pHvfT9nhtFfZfUNaYZrhz/1RhijEyCKmrSZm +JDYqYl05lSqeLjHVDY5AAtGQzgrQRKDAD87AkbR+6365Pw8xnhzkxKZndeA7Zpxy +21Q4gsUK5EdIM0RYNDyV0c12pY9S5mN5V8DDTGH+IkdghiI1UX0QjmLDjAFrHrQ3 +9VKY+t+jZtk4AkkLxovDZMgwCPulhHvHKcXG6+v+vUJ2/W1B71JU1f+JzPZdBulT +yj1XiBPKo5FkIDShRs+iAjmbOlIoIaWhJB0G4VMFy0mhtIvaiq0GoDWitUT6VaUw +P85KzananB/5MCd9wzxb4quOxD0tbapaaOXfaLwFmM06BH+WFK/TtVitJiuR+W4z +MrFpHlbGQYQOtjhFC94yuC1oGviV1vqOpswMDntGCzB7dofossqUnxh7SUxFtDIP +PWJgcFW2qmBDUu8+n0SmfaGWqNTl4FpsvdrMwpdL5X+OAq80KXcUqWtGvu7PfRJU +-----END RSA PRIVATE KEY----- diff --git a/flca/dev_assets/pwd.txt b/flca/dev_assets/pwd.txt new file mode 100644 index 000000000..fde6f26b8 --- /dev/null +++ b/flca/dev_assets/pwd.txt @@ -0,0 +1 @@ +password! 
\ No newline at end of file diff --git a/flca/dev_assets/reverse_proxy.conf b/flca/dev_assets/reverse_proxy.conf new file mode 100644 index 000000000..a0065bf89 --- /dev/null +++ b/flca/dev_assets/reverse_proxy.conf @@ -0,0 +1,10 @@ +server { + listen 443; + location / { + proxy_pass https://127.0.0.1:8000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } +} \ No newline at end of file diff --git a/flca/dev_assets/root_ca.crt b/flca/dev_assets/root_ca.crt new file mode 100644 index 000000000..790b5951c --- /dev/null +++ b/flca/dev_assets/root_ca.crt @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEazCCAp+gAwIBAgIQbHXgymGrqb/G85lgg0qzAzBBBgkqhkiG9w0BAQowNKAP +MA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZIhvcNAQEIMA0GCWCGSAFlAwQCAQUAogMC +ASAwGjEYMBYGA1UEAxMPTWVkUGVyZiBSb290IENBMB4XDTI0MDUwOTEyNTYzNFoX +DTQ0MDUwOTEyNTYzM1owGjEYMBYGA1UEAxMPTWVkUGVyZiBSb290IENBMIIBojAN +BgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAyPGt+k78Us5F61TL9AIt5rWVUUC+ +lGSZiZfDQktz43FfMKH6Kei7R4/9sTZoqKTa3XKyshMX23odrCyvKtpeBNnKQToo +yqDjLWx4wl4rcZatKIKSWou2Uhk77+ONa3T37ckAzgsHzECbRJtd/PsKk12o5PbI +0jgrcVZnXN897H78FDc5TxPQgdCN7v59URGJQ4e1H+dJlyaMF0bbN236PwQrlMA3 +kcnRZ1qsZ6R2eocFII7GCDmsdFZ3X0peg4hUz0Lf5pDYff6hf1R6VyfbqR+xJOdx +BHHO1Ak5UAxv1EYUflOZp11snqn3ZQCChRfDkoANWqdU0LIqNHlDn58YXPiXwnli +gItjyKxlphXKlFUi6juhWvW4nrYf378g3cO7echkn7NaHXudIxQp6xO8iETSUOEf +pGGuV72kapiZA9+GTFwYAK8MmuiHYxelZlqLl4fpUnxvZSrtr8m2n87a6AMRPWKS +yIxar684hQI6TRq9hJTs+HneeZT2CH7k9FqtAgMBAAGjRTBDMA4GA1UdDwEB/wQE +AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBT0pIgmVPfaZgIy7Cx1 +yxRc7SRbUzBBBgkqhkiG9w0BAQowNKAPMA0GCWCGSAFlAwQCAQUAoRwwGgYJKoZI +hvcNAQEIMA0GCWCGSAFlAwQCAQUAogMCASADggGBAK4poY8+QbdDkahXHFWc4CeL +dffenE0FKUCDvR4+vQLoFjUKVcpdnkIrkouWSys2VmZusgvUofEKMk2vKVB/kzDF +Zg0GkIeBAZr0n10LImvwed1CzmA99K0J42W3Z8ksr0Qsr+1qmx9SyVv1xsMmIw1g +0CkpjkkpvNb9Z+nKbn5q3WBEcThM8llsi9krOaj9wzDVOGx1D7TcrB7IOTmCvkrE 
+TthFqz5EuUOHW0KVyQsDBd/ktxx5zkkHeAo9RrBifG4drmKF4ZnW9JcSdD74H7+2 +K1S9U3BkCdwSDfRpuuE4v1RP2ClrfdEF4//PjCnlw+7vEkPrOugEJZO5/IZzRApb +By5R8fgK6ChDUAWmx67CUb5PGG9ugGumVKpnk4Hwa+hixn5f19MeXTLEXW2Wmkp4 +SD1vsz9zN9HTk4dMpPJCO0oIWvbzEZcpCJt4kiBshzZK5GgP8PxGEqjxC6Vr/sUv +xgP8tlytyAGjZMcLPsXrMZMtlhxCEt7y9mMZ7NGLCA== +-----END CERTIFICATE----- diff --git a/flca/dev_assets/server.tpl b/flca/dev_assets/server.tpl new file mode 100644 index 000000000..78cc4d606 --- /dev/null +++ b/flca/dev_assets/server.tpl @@ -0,0 +1,10 @@ +{ + "subject": {{ toJson .Subject }}, + "sans": {{ toJson .SANs }}, +{{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + "keyUsage": ["dataEncipherment", "digitalSignature"], +{{- else }} + {{ fail "Key type must be RSA. Try again with --kty=RSA" }} +{{- end }} + "extKeyUsage": ["serverAuth", "clientAuth"] +} \ No newline at end of file diff --git a/flca/dev_assets/settings.json b/flca/dev_assets/settings.json new file mode 100644 index 000000000..e2aa5202f --- /dev/null +++ b/flca/dev_assets/settings.json @@ -0,0 +1,11 @@ +{ +"ca_config": "/assets/ca.json", +"intermediate_ca_key": "/assets/intermediate_ca_key", +"intermediate_ca_password": "/assets/pwd.txt", +"root_ca_crt": "/assets/root_ca.crt", +"intermediate_ca_crt": "/assets/intermediate_ca.crt", +"client_x509_template": "/assets/client.tpl", +"server_x509_template": "/assets/server.tpl", +"proxy_config": "/assets/reverse_proxy.conf", +"db_config": "/assets/db_config.json" +} \ No newline at end of file diff --git a/flca/dev_utils.py b/flca/dev_utils.py new file mode 100644 index 000000000..05378d0dd --- /dev/null +++ b/flca/dev_utils.py @@ -0,0 +1,23 @@ +import json +import os + + +def safe_store(content: str, path: str): + with open(path, "w") as f: + pass + os.chmod(path, 0o600) + with open(path, "a") as f: + f.write(content) + + +def get_all_secrets(): + # load settings + with open("/assets/settings.json") as f: + settings = json.load(f) + + # get secrets + secrets = {} + for key in 
settings.keys(): + with open(settings[key]) as f: + secrets[key] = f.read() + return secrets diff --git a/flca/entrypoint.sh b/flca/entrypoint.sh new file mode 100644 index 000000000..35325b77c --- /dev/null +++ b/flca/entrypoint.sh @@ -0,0 +1,13 @@ +export STEPPATH=$(step path) +python /setup.py +step-ca --password-file=$STEPPATH/secrets/pwd.txt $STEPPATH/config/ca.json & + +if [[ -n "$USE_PROXY" ]]; then + STATUS="1" + while [ "$STATUS" -ne "0" ]; do + sleep 1 + step ca health --ca-url 127.0.0.1:8000 + STATUS="$?" + done + nginx -g "daemon off;" +fi diff --git a/flca/manual_setup/README.md b/flca/manual_setup/README.md new file mode 100644 index 000000000..d4a37c4e5 --- /dev/null +++ b/flca/manual_setup/README.md @@ -0,0 +1 @@ +This contains code used to generate the root and intermediate ca certs and keys. You need to install `step` to use it (look at the dockerfiles). diff --git a/flca/manual_setup/create_keys.sh b/flca/manual_setup/create_keys.sh new file mode 100644 index 000000000..5dad840bd --- /dev/null +++ b/flca/manual_setup/create_keys.sh @@ -0,0 +1,17 @@ +step certificate create "MedPerf Root CA" \ + ./root_ca.crt \ + ./root_ca.key \ + --template ./rsa_root_ca.tpl \ + --kty RSA \ + --not-after 175320h \ + --size 3072 + +step certificate create "MedPerf Intermediate CA" \ + ./intermediate_ca.crt \ + ./intermediate_ca.key \ + --ca ./root_ca.crt \ + --ca-key ./root_ca.key \ + --template ./rsa_intermediate_ca.tpl \ + --kty RSA \ + --not-after 87660h \ + --size 3072 diff --git a/flca/manual_setup/rsa_intermediate_ca.tpl b/flca/manual_setup/rsa_intermediate_ca.tpl new file mode 100644 index 000000000..3f5606989 --- /dev/null +++ b/flca/manual_setup/rsa_intermediate_ca.tpl @@ -0,0 +1,12 @@ +{ + "subject": {{ toJson .Subject }}, + "issuer": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": 0 + } + {{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + , "signatureAlgorithm": 
"SHA256-RSAPSS" + {{- end }} +} \ No newline at end of file diff --git a/flca/manual_setup/rsa_root_ca.tpl b/flca/manual_setup/rsa_root_ca.tpl new file mode 100644 index 000000000..150812984 --- /dev/null +++ b/flca/manual_setup/rsa_root_ca.tpl @@ -0,0 +1,12 @@ +{ + "subject": {{ toJson .Subject }}, + "issuer": {{ toJson .Subject }}, + "keyUsage": ["certSign", "crlSign"], + "basicConstraints": { + "isCA": true, + "maxPathLen": 1 + } + {{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} + , "signatureAlgorithm": "SHA256-RSAPSS" + {{- end }} +} \ No newline at end of file diff --git a/flca/setup.py b/flca/setup.py new file mode 100644 index 000000000..06b9acd3f --- /dev/null +++ b/flca/setup.py @@ -0,0 +1,96 @@ +import os +import json +from utils import get_all_secrets, safe_store + + +def validate(secrets: dict): + """main settings are expected to be a json file that contains secrets IDs of other objects""" + expected_keys = set( + [ + "ca_config", + "intermediate_ca_key", + "intermediate_ca_password", + "root_ca_crt", + "intermediate_ca_crt", + "client_x509_template", + "server_x509_template", + "db_config", + ] + ) + if os.environ.get("USE_PROXY", None): + expected_keys.add("proxy_config") + + if expected_keys != set(secrets.keys()): + msg = "Expected keys: " + ", ".join(expected_keys) + msg += "\nFound keys: " + ", ".join(set(secrets.keys())) + raise ValueError(msg) + + +def setup(): + step_path = os.environ.get("STEPPATH", None) + if step_path is None: + raise Exception("STEPPATH var is not set") + + secrets = get_all_secrets() + validate(secrets) + + # Create folders + secrets_folder = os.path.join(step_path, "secrets") + certs_folder = os.path.join(step_path, "certs") + config_folder = os.path.join(step_path, "config") + templates_folder = os.path.join(step_path, "templates", "certs", "x509") + os.makedirs(secrets_folder, mode=0o600) + os.makedirs(certs_folder, mode=0o600) + os.makedirs(config_folder, mode=0o600) + os.makedirs(templates_folder, 
mode=0o600) + + # store key and its password + intermediate_ca_key_path = os.path.join(secrets_folder, "intermediate_ca_key") + safe_store(secrets["intermediate_ca_key"], intermediate_ca_key_path) + + password_path = os.path.join(secrets_folder, "pwd.txt") + safe_store(secrets["intermediate_ca_password"], password_path) + + # store root and intermediate certs + root_ca_crt_path = os.path.join(certs_folder, "root_ca.crt") + safe_store(secrets["root_ca_crt"], root_ca_crt_path) + + intermediate_ca_crt_path = os.path.join(certs_folder, "intermediate_ca.crt") + safe_store(secrets["intermediate_ca_crt"], intermediate_ca_crt_path) + + # store signing templates + client_tpl_path = os.path.join(templates_folder, "client.tpl") + safe_store(secrets["client_x509_template"], client_tpl_path) + + server_tpl_path = os.path.join(templates_folder, "server.tpl") + safe_store(secrets["server_x509_template"], server_tpl_path) + + # Get config + config = json.loads(secrets["ca_config"]) + + # Override config with runtime paths + config["root"] = root_ca_crt_path + config["crt"] = intermediate_ca_crt_path + config["key"] = intermediate_ca_key_path + # assuming server provisioner is the first one, and client is second + config["authority"]["provisioners"][0]["options"]["x509"][ + "templateFile" + ] = server_tpl_path + config["authority"]["provisioners"][1]["options"]["x509"][ + "templateFile" + ] = client_tpl_path + + # Override db config + config["db"] = json.loads(secrets["db_config"]) + + # write config + ca_config_path = os.path.join(config_folder, "ca.json") + safe_store(json.dumps(config), ca_config_path) + + # setup proxy + if os.environ.get("USE_PROXY", None): + safe_store(secrets["proxy_config"], "/etc/nginx/http.d/reverse-proxy.conf") + + +if __name__ == "__main__": + setup() diff --git a/flca/utils.py b/flca/utils.py new file mode 100644 index 000000000..f316c8baf --- /dev/null +++ b/flca/utils.py @@ -0,0 +1,49 @@ +import google.auth +from google.cloud import secretmanager 
+import json +import os + + +def get_secret(secret_name: str): + """Code copied and modified from medperf/server/medperf/settings.py""" + + try: + _, os.environ["GOOGLE_CLOUD_PROJECT"] = google.auth.default() + except google.auth.exceptions.DefaultCredentialsError: + raise Exception( + "No local .env or GOOGLE_CLOUD_PROJECT detected. No secrets found." + ) + + # Pull secrets from Secret Manager + print("Loading env from GCP secrets manager") + project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") + + client = secretmanager.SecretManagerServiceClient() + settings_version = os.environ.get("SETTINGS_SECRETS_VERSION", "latest") + name = f"projects/{project_id}/secrets/{secret_name}/versions/{settings_version}" + payload = client.access_secret_version(name=name).payload.data.decode("UTF-8") + return payload + + +def safe_store(content: str, path: str): + with open(path, "w") as f: + pass + os.chmod(path, 0o600) + with open(path, "a") as f: + f.write(content) + + +def get_all_secrets(): + settings_name = os.environ.get("SETTINGS_SECRETS_NAME", None) + if settings_name is None: + raise Exception("SETTINGS_SECRETS_NAME var is not set") + + # load settings + settings = get_secret(settings_name) + settings_dict: dict = json.loads(settings) + + # get secrets + secrets = {} + for key in settings_dict.keys(): + secrets[key] = get_secret(settings_dict[key]) + return secrets From 5a1256a9007ef32c26e4da83ddf785a9ac23dad2 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 22:42:26 +0200 Subject: [PATCH 067/242] temporary measures for minimum UI friction --- cli/cli_tests_training.sh | 183 ++++++++---------------- cli/medperf/commands/dataset/dataset.py | 3 +- cli/medperf/commands/dataset/train.py | 68 ++++++--- server/training/views.py | 6 +- 4 files changed, 113 insertions(+), 147 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 7684fd48c..dfed35bbf 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -244,6 
+244,34 @@ checkFailed "Data1 association step failed" echo "\n" +########################################################## +echo "=====================================" +echo "Get data1owner cert" +echo "=====================================" +medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get data1owner cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting training with data1" +echo "=====================================" +medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & +COL1_PID=$! + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +# Check if the command is still running. +if [ ! -d "/proc/$COL1_PID" ]; then + checkFailed "data1 training doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Activate dataowner2 profile" @@ -297,101 +325,69 @@ echo "\n" ########################################################## echo "=====================================" -echo "Activate modelowner profile" -echo "=====================================" -medperf profile activate testmodel -checkFailed "testmodel profile activation failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Approve data1 association" -echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_1_UID -checkFailed "data1 association approval failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Approve 
data2 association" +echo "Get data2owner cert" echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_2_UID -checkFailed "data2 association approval failed" +medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get data2owner cert failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "start event" +echo "Starting training with data2" echo "=====================================" -medperf training start_event -n event1 -t $TRAINING_UID -y -checkFailed "start event failed" +medperf dataset train -d $DSET_2_UID -t $TRAINING_UID col2.log 2>&1 & +COL2_PID=$! -########################################################## - -echo "\n" +# sleep so that the mlcube is run before we change profiles +sleep 7 -########################################################## -echo "=====================================" -echo "Activate aggowner profile" -echo "=====================================" -medperf profile activate testagg -checkFailed "testagg profile activation failed" +# Check if the command is still running. +if [ ! 
-d "/proc/$COL2_PID" ]; then + checkFailed "data2 training doesn't seem to be running" 1 +fi ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Get aggregator cert" +echo "Activate modelowner profile" echo "=====================================" -medperf certificate get_server_certificate -t $TRAINING_UID -checkFailed "Get aggregator cert failed" +medperf profile activate testmodel +checkFailed "testmodel profile activation failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Activate dataowner profile" +echo "Approve data1 association" echo "=====================================" -medperf profile activate testdata1 -checkFailed "testdata1 profile activation failed" +medperf association approve -t $TRAINING_UID -d $DSET_1_UID +checkFailed "data1 association approval failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Get dataowner cert" +echo "Approve data2 association" echo "=====================================" -medperf certificate get_client_certificate -t $TRAINING_UID -checkFailed "Get dataowner cert failed" +medperf association approve -t $TRAINING_UID -d $DSET_2_UID +checkFailed "data2 association approval failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Activate dataowner2 profile" +echo "start event" echo "=====================================" -medperf profile activate testdata2 -checkFailed "testdata2 profile activation failed" -########################################################## - -echo "\n" +medperf training 
start_event -n event1 -t $TRAINING_UID -y +checkFailed "start event failed" -########################################################## -echo "=====================================" -echo "Get dataowner2 cert" -echo "=====================================" -medperf certificate get_client_certificate -t $TRAINING_UID -checkFailed "Get dataowner2 cert failed" ########################################################## echo "\n" @@ -406,71 +402,22 @@ checkFailed "testagg profile activation failed" echo "\n" -TRAINING_UID=1 -DSET_1_UID=1 -DSET_2_UID=2 -########################################################## -echo "=====================================" -echo "Starting aggregator" -echo "=====================================" -medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ agg.log 2>&1 & -AGG_PID=$! - -# sleep so that the mlcube is run before we change profiles -sleep 7 - -# Check if the command is still running. -if [ ! -d "/proc/$AGG_PID" ]; then - checkFailed "agg doesn't seem to be running" 1 -fi -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate dataowner profile" -echo "=====================================" -medperf profile activate testdata1 -checkFailed "testdata1 profile activation failed" -########################################################## - -echo "\n" - ########################################################## echo "=====================================" -echo "Starting training with data1" -echo "=====================================" -medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & -COL1_PID=$! - -# sleep so that the mlcube is run before we change profiles -sleep 7 - -# Check if the command is still running. -if [ ! 
-d "/proc/$COL1_PID" ]; then - checkFailed "data1 training doesn't seem to be running" 1 -fi -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate dataowner2 profile" +echo "Get aggregator cert" echo "=====================================" -medperf profile activate testdata2 -checkFailed "testdata2 profile activation failed" +medperf certificate get_server_certificate -t $TRAINING_UID +checkFailed "Get aggregator cert failed" ########################################################## echo "\n" ########################################################## echo "=====================================" -echo "Starting training with data2" +echo "Starting aggregator" echo "=====================================" -medperf dataset train -d $DSET_2_UID -t $TRAINING_UID -checkFailed "data2 training failed" +medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ +checkFailed "agg didn't exit successfully" ########################################################## echo "\n" @@ -487,18 +434,8 @@ echo "=====================================" # a process that is not a child of the current shell wait $COL1_PID checkFailed "data1 training didn't exit successfully" -wait $AGG_PID -checkFailed "aggregator didn't exit successfully" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate modelowner profile" -echo "=====================================" -medperf profile activate testmodel -checkFailed "testmodel profile activation failed" +wait $COL2_PID +checkFailed "data2 training didn't exit successfully" ########################################################## echo "\n" diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 77f4e0431..375f141ec 100644 --- 
a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -157,9 +157,10 @@ def train( overwrite: bool = typer.Option( False, "--overwrite", help="Overwrite outputs if present" ), + approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs training""" - TrainingExecution.run(training_exp_id, data_uid, overwrite) + TrainingExecution.run(training_exp_id, data_uid, overwrite, approval) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 7fda13447..521ddcb48 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -3,45 +3,62 @@ from medperf.account_management.account_management import get_medperf_user_data from medperf.entities.ca import CA from medperf.entities.event import TrainingEvent -from medperf.exceptions import InvalidArgumentError, MedperfException +from medperf.exceptions import CleanExit, InvalidArgumentError, MedperfException from medperf.entities.training_exp import TrainingExp from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube -from medperf.utils import get_pki_assets_path, get_participant_label, remove_path +from medperf.utils import ( + approval_prompt, + dict_pretty_print, + get_pki_assets_path, + get_participant_label, + remove_path, +) from medperf.certificates import trust class TrainingExecution: @classmethod - def run(cls, training_exp_id: int, data_uid: int, overwrite: bool = False): + def run( + cls, + training_exp_id: int, + data_uid: int, + overwrite: bool = False, + approved: bool = False, + ): """Starts the aggregation server of a training experiment Args: training_exp_id (int): Training experiment UID. 
""" - execution = cls(training_exp_id, data_uid, overwrite) + execution = cls(training_exp_id, data_uid, overwrite, approved) execution.prepare() execution.validate() execution.check_existing_outputs() - execution.prepare_training_cube() execution.prepare_plan() execution.prepare_pki_assets() + execution.confirm_run() + execution.prepare_training_cube() with config.ui.interactive(): execution.run_experiment() - def __init__(self, training_exp_id: int, data_uid: int, overwrite: bool) -> None: + def __init__( + self, training_exp_id: int, data_uid: int, overwrite: bool, approved: bool + ) -> None: self.training_exp_id = training_exp_id self.data_uid = data_uid self.overwrite = overwrite self.ui = config.ui + self.approved = approved def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Execution: {self.training_exp.name}") - self.event = TrainingEvent.from_experiment(self.training_exp_id) + # self.event = TrainingEvent.from_experiment(self.training_exp_id) self.dataset = Dataset.get(self.data_uid) self.user_email: str = get_medperf_user_data()["email"] - self.out_logs = os.path.join(self.event.col_out_logs, str(self.dataset.id)) + # self.out_logs = os.path.join(self.event.col_out_logs, str(self.dataset.id)) + self.out_logs = os.path.join(self.training_exp.path, str(self.dataset.id)) def validate(self): if self.dataset.id is None: @@ -52,9 +69,9 @@ def validate(self): msg = "The provided dataset is not operational." raise InvalidArgumentError(msg) - if self.event.finished: - msg = "The provided training experiment has to start a training event." - raise InvalidArgumentError(msg) + # if self.event.finished: + # msg = "The provided training experiment has to start a training event." 
+ # raise InvalidArgumentError(msg) def check_existing_outputs(self): msg = ( @@ -68,6 +85,26 @@ def check_existing_outputs(self): raise MedperfException(msg) remove_path(path) + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def confirm_run(self): + msg = ( + "Above is the training configuration that will be used." + " Do you confirm starting training? [Y/n] " + ) + dict_pretty_print(self.training_exp.plan) + self.approved = self.approved or approval_prompt(msg) + + if not self.approved: + raise CleanExit("Training cancelled.") + def prepare_training_cube(self): self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") @@ -78,15 +115,6 @@ def __get_cube(self, uid: int, name: str) -> Cube: self.ui.print(f"> {name} cube download complete") return cube - def prepare_plan(self): - self.training_exp.prepare_plan() - - def prepare_pki_assets(self): - ca = CA.from_experiment(self.training_exp_id) - trust(ca) - self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) - self.ca = ca - def run_experiment(self): participant_label = get_participant_label(self.user_email, self.dataset.id) env_dict = {"MEDPERF_PARTICIPANT_LABEL": participant_label} diff --git a/server/training/views.py b/server/training/views.py index 0a9a75d54..f28e86cd8 100644 --- a/server/training/views.py +++ b/server/training/views.py @@ -108,9 +108,9 @@ def get(self, request, pk, format=None): class TrainingCA(GenericAPIView): - permission_classes = [ - IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner - ] + # permission_classes = [ + # IsAdmin | IsExpOwner | IsAssociatedDatasetOwner | IsAggregatorOwner + # ] serializer_class = CASerializer queryset = "" From 3dd45ffdb6e8abbd39ebb0f13a8161a6792e1281 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 22:44:32 
+0200 Subject: [PATCH 068/242] empty From 712bcf618f7838e84892944ff016161f011620ae Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 22:56:45 +0200 Subject: [PATCH 069/242] post-merge main --- cli/cli_tests_training.sh | 118 +++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index dfed35bbf..0172711ce 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -9,16 +9,16 @@ echo "==========================================" echo "Creating test profiles for each user" echo "==========================================" -medperf profile activate local +print_eval medperf profile activate local checkFailed "local profile creation failed" -medperf profile create -n testmodel +print_eval medperf profile create -n testmodel checkFailed "testmodel profile creation failed" -medperf profile create -n testagg +print_eval medperf profile create -n testagg checkFailed "testagg profile creation failed" -medperf profile create -n testdata1 +print_eval medperf profile create -n testdata1 checkFailed "testdata1 profile creation failed" -medperf profile create -n testdata2 +print_eval medperf profile create -n testdata2 checkFailed "testdata2 profile creation failed" ########################################################## @@ -48,28 +48,28 @@ echo "\n" echo "==========================================" echo "Login each user" echo "==========================================" -medperf profile activate testmodel +print_eval medperf profile activate testmodel checkFailed "testmodel profile activation failed" -medperf auth login -e $MODELOWNER +print_eval medperf auth login -e $MODELOWNER checkFailed "testmodel login failed" -medperf profile activate testagg +print_eval medperf profile activate testagg checkFailed "testagg profile activation failed" -medperf auth login -e $AGGOWNER +print_eval medperf auth login -e $AGGOWNER checkFailed "testagg login failed" 
-medperf profile activate testdata1 +print_eval medperf profile activate testdata1 checkFailed "testdata1 profile activation failed" -medperf auth login -e $DATAOWNER +print_eval medperf auth login -e $DATAOWNER checkFailed "testdata1 login failed" -medperf profile activate testdata2 +print_eval medperf profile activate testdata2 checkFailed "testdata2 profile activation failed" -medperf auth login -e $DATAOWNER2 +print_eval medperf auth login -e $DATAOWNER2 checkFailed "testdata2 login failed" ########################################################## @@ -79,7 +79,7 @@ echo "\n" echo "=====================================" echo "Activate modelowner profile" echo "=====================================" -medperf profile activate testmodel +print_eval medperf profile activate testmodel checkFailed "testmodel profile activation failed" ########################################################## @@ -90,11 +90,11 @@ echo "=====================================" echo "Submit cubes" echo "=====================================" -medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE --operational +print_eval medperf mlcube submit --name trainprep -m $PREP_TRAINING_MLCUBE --operational checkFailed "Train prep submission failed" PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) -medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -a $TRAIN_WEIGHTS --operational +print_eval medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -a $TRAIN_WEIGHTS --operational checkFailed "traincube submission failed" TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) ########################################################## @@ -105,7 +105,7 @@ echo "\n" echo "=====================================" echo "Submit Training Experiment" echo "=====================================" -medperf training submit -n trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID +print_eval medperf training submit -n 
trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID checkFailed "Training exp submission failed" TRAINING_UID=$(medperf training ls | grep trainexp | tail -n 1 | tr -s ' ' | cut -d ' ' -f 2) @@ -123,7 +123,7 @@ echo "=====================================" echo "Associate with ca" echo "=====================================" CA_UID=$(medperf ca ls | grep "MedPerf CA" | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) -medperf ca associate -t $TRAINING_UID -c $CA_UID -y +print_eval medperf ca associate -t $TRAINING_UID -c $CA_UID -y checkFailed "ca association failed" ########################################################## @@ -133,7 +133,7 @@ echo "\n" echo "=====================================" echo "Activate aggowner profile" echo "=====================================" -medperf profile activate testagg +print_eval medperf profile activate testagg checkFailed "testagg profile activation failed" ########################################################## @@ -145,7 +145,7 @@ echo "Running aggregator submission step" echo "=====================================" HOSTNAME_=$(hostname -I | cut -d " " -f 1) # HOSTNAME_=$(hostname -A | cut -d " " -f 1) # fqdn on github CI runner doesn't resolve from inside containers -medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID +print_eval medperf aggregator submit -n aggreg -a $HOSTNAME_ -p 50273 -m $TRAINCUBE_UID checkFailed "aggregator submission step failed" AGG_UID=$(medperf aggregator ls | grep aggreg | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## @@ -156,7 +156,7 @@ echo "\n" echo "=====================================" echo "Running aggregator association step" echo "=====================================" -medperf aggregator associate -a $AGG_UID -t $TRAINING_UID -y +print_eval medperf aggregator associate -a $AGG_UID -t $TRAINING_UID -y checkFailed "aggregator association step failed" 
########################################################## @@ -166,7 +166,7 @@ echo "\n" echo "=====================================" echo "Activate modelowner profile" echo "=====================================" -medperf profile activate testmodel +print_eval medperf profile activate testmodel checkFailed "testmodel profile activation failed" ########################################################## @@ -176,7 +176,7 @@ echo "\n" echo "=====================================" echo "Approve aggregator association" echo "=====================================" -medperf association approve -t $TRAINING_UID -a $AGG_UID +print_eval medperf association approve -t $TRAINING_UID -a $AGG_UID checkFailed "agg association approval failed" ########################################################## @@ -186,7 +186,7 @@ echo "\n" echo "=====================================" echo "submit plan" echo "=====================================" -medperf training set_plan -t $TRAINING_UID -c $TRAINING_CONFIG -y +print_eval medperf training set_plan -t $TRAINING_UID -c $TRAINING_CONFIG -y checkFailed "submit plan failed" ########################################################## @@ -197,7 +197,7 @@ echo "\n" echo "=====================================" echo "Activate dataowner profile" echo "=====================================" -medperf profile activate testdata1 +print_eval medperf profile activate testdata1 checkFailed "testdata1 profile activation failed" ########################################################## @@ -207,7 +207,7 @@ echo "\n" echo "=====================================" echo "Running data1 submission step" echo "=====================================" -medperf dataset submit -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" -y +print_eval medperf dataset submit -p $PREP_UID -d $DIRECTORY/col1 -l $DIRECTORY/col1 --name="col1" --description="col1data" --location="col1location" -y checkFailed "Data1 
submission step failed" DSET_1_UID=$(medperf dataset ls | grep col1 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## @@ -218,7 +218,7 @@ echo "\n" echo "=====================================" echo "Running data1 preparation step" echo "=====================================" -medperf dataset prepare -d $DSET_1_UID +print_eval medperf dataset prepare -d $DSET_1_UID checkFailed "Data1 preparation step failed" ########################################################## @@ -228,7 +228,7 @@ echo "\n" echo "=====================================" echo "Running data1 set_operational step" echo "=====================================" -medperf dataset set_operational -d $DSET_1_UID -y +print_eval medperf dataset set_operational -d $DSET_1_UID -y checkFailed "Data1 set_operational step failed" ########################################################## @@ -238,7 +238,7 @@ echo "\n" echo "=====================================" echo "Running data1 association step" echo "=====================================" -medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y +print_eval medperf dataset associate -d $DSET_1_UID -t $TRAINING_UID -y checkFailed "Data1 association step failed" ########################################################## @@ -248,7 +248,7 @@ echo "\n" echo "=====================================" echo "Get data1owner cert" echo "=====================================" -medperf certificate get_client_certificate -t $TRAINING_UID +print_eval medperf certificate get_client_certificate -t $TRAINING_UID checkFailed "Get data1owner cert failed" ########################################################## @@ -258,7 +258,7 @@ echo "\n" echo "=====================================" echo "Starting training with data1" echo "=====================================" -medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & +print_eval medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & 
COL1_PID=$! # sleep so that the mlcube is run before we change profiles @@ -276,7 +276,7 @@ echo "\n" echo "=====================================" echo "Activate dataowner2 profile" echo "=====================================" -medperf profile activate testdata2 +print_eval medperf profile activate testdata2 checkFailed "testdata2 profile activation failed" ########################################################## @@ -286,7 +286,7 @@ echo "\n" echo "=====================================" echo "Running data2 submission step" echo "=====================================" -medperf dataset submit -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" -y +print_eval medperf dataset submit -p $PREP_UID -d $DIRECTORY/col2 -l $DIRECTORY/col2 --name="col2" --description="col2data" --location="col2location" -y checkFailed "Data2 submission step failed" DSET_2_UID=$(medperf dataset ls | grep col2 | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) ########################################################## @@ -297,7 +297,7 @@ echo "\n" echo "=====================================" echo "Running data2 preparation step" echo "=====================================" -medperf dataset prepare -d $DSET_2_UID +print_eval medperf dataset prepare -d $DSET_2_UID checkFailed "Data2 preparation step failed" ########################################################## @@ -307,7 +307,7 @@ echo "\n" echo "=====================================" echo "Running data2 set_operational step" echo "=====================================" -medperf dataset set_operational -d $DSET_2_UID -y +print_eval medperf dataset set_operational -d $DSET_2_UID -y checkFailed "Data2 set_operational step failed" ########################################################## @@ -317,7 +317,7 @@ echo "\n" echo "=====================================" echo "Running data2 association step" echo "=====================================" -medperf dataset associate -d 
$DSET_2_UID -t $TRAINING_UID -y +print_eval medperf dataset associate -d $DSET_2_UID -t $TRAINING_UID -y checkFailed "Data2 association step failed" ########################################################## @@ -327,7 +327,7 @@ echo "\n" echo "=====================================" echo "Get data2owner cert" echo "=====================================" -medperf certificate get_client_certificate -t $TRAINING_UID +print_eval medperf certificate get_client_certificate -t $TRAINING_UID checkFailed "Get data2owner cert failed" ########################################################## @@ -337,7 +337,7 @@ echo "\n" echo "=====================================" echo "Starting training with data2" echo "=====================================" -medperf dataset train -d $DSET_2_UID -t $TRAINING_UID col2.log 2>&1 & +print_eval medperf dataset train -d $DSET_2_UID -t $TRAINING_UID col2.log 2>&1 & COL2_PID=$! # sleep so that the mlcube is run before we change profiles @@ -355,7 +355,7 @@ echo "\n" echo "=====================================" echo "Activate modelowner profile" echo "=====================================" -medperf profile activate testmodel +print_eval medperf profile activate testmodel checkFailed "testmodel profile activation failed" ########################################################## @@ -365,7 +365,7 @@ echo "\n" echo "=====================================" echo "Approve data1 association" echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_1_UID +print_eval medperf association approve -t $TRAINING_UID -d $DSET_1_UID checkFailed "data1 association approval failed" ########################################################## @@ -375,7 +375,7 @@ echo "\n" echo "=====================================" echo "Approve data2 association" echo "=====================================" -medperf association approve -t $TRAINING_UID -d $DSET_2_UID +print_eval medperf association approve -t $TRAINING_UID -d $DSET_2_UID 
checkFailed "data2 association approval failed" ########################################################## @@ -385,7 +385,7 @@ echo "\n" echo "=====================================" echo "start event" echo "=====================================" -medperf training start_event -n event1 -t $TRAINING_UID -y +print_eval medperf training start_event -n event1 -t $TRAINING_UID -y checkFailed "start event failed" ########################################################## @@ -396,7 +396,7 @@ echo "\n" echo "=====================================" echo "Activate aggowner profile" echo "=====================================" -medperf profile activate testagg +print_eval medperf profile activate testagg checkFailed "testagg profile activation failed" ########################################################## @@ -406,7 +406,7 @@ echo "\n" echo "=====================================" echo "Get aggregator cert" echo "=====================================" -medperf certificate get_server_certificate -t $TRAINING_UID +print_eval medperf certificate get_server_certificate -t $TRAINING_UID checkFailed "Get aggregator cert failed" ########################################################## @@ -416,7 +416,7 @@ echo "\n" echo "=====================================" echo "Starting aggregator" echo "=====================================" -medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ +print_eval medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ checkFailed "agg didn't exit successfully" ########################################################## @@ -444,7 +444,7 @@ echo "\n" echo "=====================================" echo "close event" echo "=====================================" -medperf training close_event -t $TRAINING_UID -y +print_eval medperf training close_event -t $TRAINING_UID -y checkFailed "close event failed" ########################################################## @@ -455,28 +455,28 @@ echo "\n" echo "=====================================" echo "Logout 
users" echo "=====================================" -medperf profile activate testmodel +print_eval medperf profile activate testmodel checkFailed "testmodel profile activation failed" -medperf auth logout +print_eval medperf auth logout checkFailed "logout failed" -medperf profile activate testagg +print_eval medperf profile activate testagg checkFailed "testagg profile activation failed" -medperf auth logout +print_eval medperf auth logout checkFailed "logout failed" -medperf profile activate testdata1 +print_eval medperf profile activate testdata1 checkFailed "testdata1 profile activation failed" -medperf auth logout +print_eval medperf auth logout checkFailed "logout failed" -medperf profile activate testdata2 +print_eval medperf profile activate testdata2 checkFailed "testdata2 profile activation failed" -medperf auth logout +print_eval medperf auth logout checkFailed "logout failed" ########################################################## @@ -486,19 +486,19 @@ echo "\n" echo "=====================================" echo "Delete test profiles" echo "=====================================" -medperf profile activate default +print_eval medperf profile activate default checkFailed "default profile activation failed" -medperf profile delete testmodel +print_eval medperf profile delete testmodel checkFailed "Profile deletion failed" -medperf profile delete testagg +print_eval medperf profile delete testagg checkFailed "Profile deletion failed" -medperf profile delete testdata1 +print_eval medperf profile delete testdata1 checkFailed "Profile deletion failed" -medperf profile delete testdata2 +print_eval medperf profile delete testdata2 checkFailed "Profile deletion failed" ########################################################## From e76bea651fdc88f983a4867c4ede25b32c6b592d Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 23:03:58 +0200 Subject: [PATCH 070/242] skip approval in tests --- cli/cli_tests_training.sh | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 0172711ce..fcfd7b4f4 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -258,7 +258,7 @@ echo "\n" echo "=====================================" echo "Starting training with data1" echo "=====================================" -print_eval medperf dataset train -d $DSET_1_UID -t $TRAINING_UID col1.log 2>&1 & +print_eval medperf dataset train -d $DSET_1_UID -t $TRAINING_UID -y col1.log 2>&1 & COL1_PID=$! # sleep so that the mlcube is run before we change profiles @@ -337,7 +337,7 @@ echo "\n" echo "=====================================" echo "Starting training with data2" echo "=====================================" -print_eval medperf dataset train -d $DSET_2_UID -t $TRAINING_UID col2.log 2>&1 & +print_eval medperf dataset train -d $DSET_2_UID -t $TRAINING_UID -y col2.log 2>&1 & COL2_PID=$! # sleep so that the mlcube is run before we change profiles From 97b1bacf67d13c9fef3656840a50e8297e5cc6fa Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 23:48:27 +0200 Subject: [PATCH 071/242] bug fix for association ls --- cli/medperf/commands/association/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cli/medperf/commands/association/utils.py b/cli/medperf/commands/association/utils.py index 83b785c5e..78ff742d7 100644 --- a/cli/medperf/commands/association/utils.py +++ b/cli/medperf/commands/association/utils.py @@ -87,11 +87,17 @@ def get_associations_list( "user": config.comms.get_user_training_datasets_associations, "experiment": config.comms.get_training_datasets_associations, }, - "aggregator": config.comms.get_user_training_aggregators_associations, - "ca": config.comms.get_user_training_cas_associations, + "aggregator": { + "user": config.comms.get_user_training_aggregators_associations, + }, + "ca": { + "user": config.comms.get_user_training_cas_associations, + }, }, "benchmark": { - 
"dataset": config.comms.get_user_benchmarks_datasets_associations, + "dataset": { + "user": config.comms.get_user_benchmarks_datasets_associations, + }, "mode_mlcube": { "user": config.comms.get_user_benchmarks_models_associations, "experiment": config.comms.get_benchmark_models_associations, From af650e5bd649fcc823d6f601c29f12755df4310e Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 21 May 2024 23:52:23 +0200 Subject: [PATCH 072/242] better UI for data owners --- cli/medperf/commands/dataset/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 521ddcb48..231db58d9 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -38,8 +38,8 @@ def run( execution.prepare_plan() execution.prepare_pki_assets() execution.confirm_run() - execution.prepare_training_cube() with config.ui.interactive(): + execution.prepare_training_cube() execution.run_experiment() def __init__( @@ -109,7 +109,9 @@ def prepare_training_cube(self): self.cube = self.__get_cube(self.training_exp.fl_mlcube, "FL") def __get_cube(self, uid: int, name: str) -> Cube: - self.ui.text = f"Retrieving {name} cube" + self.ui.text = ( + "Retrieving and setting up training MLCube. This may take some time." 
+ ) cube = Cube.get(uid) cube.download_run_files() self.ui.print(f"> {name} cube download complete") From 4777e60c1877169f43dab2e78835ff7894d6571c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 22 May 2024 21:06:24 +0200 Subject: [PATCH 073/242] config storage migration --- cli/medperf/storage/__init__.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/cli/medperf/storage/__init__.py b/cli/medperf/storage/__init__.py index acebb6f31..59097b2ff 100644 --- a/cli/medperf/storage/__init__.py +++ b/cli/medperf/storage/__init__.py @@ -2,7 +2,7 @@ import shutil from medperf import config -from medperf.config_management import read_config, write_config +from medperf.config_management import read_config, write_config, ConfigManager from .utils import full_folder_path @@ -19,12 +19,7 @@ def init_storage(): os.makedirs(folder, exist_ok=True) -def apply_configuration_migrations(): - if not os.path.exists(config.config_path): - return - - config_p = read_config() - +def __apply_logs_migrations(config_p: ConfigManager): if "logs_folder" not in config_p.storage: return @@ -35,4 +30,27 @@ def apply_configuration_migrations(): del config_p.storage["logs_folder"] + +def __apply_training_migrations(config_p: ConfigManager): + + for folder in [ + "aggregators_folder", + "cas_folder", + "training_events_folder", + "training_folder", + ]: + if folder not in config_p.storage: + # Assuming for now all folders are always moved together + # I used here "benchmarks_folder" arbitrarily + config_p.storage[folder] = config_p.storage["benchmarks_folder"] + + +def apply_configuration_migrations(): + if not os.path.exists(config.config_path): + return + + config_p = read_config() + __apply_logs_migrations(config_p) + __apply_training_migrations(config_p) + write_config(config_p) From 538ea4e9e280e7a3cca89a44487c6e254ac27c6a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 22 May 2024 21:06:48 +0200 Subject: [PATCH 074/242] update medperf 
version --- cli/medperf/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/_version.py b/cli/medperf/_version.py index ae7362549..bbab0242f 100644 --- a/cli/medperf/_version.py +++ b/cli/medperf/_version.py @@ -1 +1 @@ -__version__ = "0.1.3" +__version__ = "0.1.4" From 6b42235055dbf9a685602b017cebe76c5f05be97 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 22 May 2024 21:07:26 +0200 Subject: [PATCH 075/242] use check update again --- cli/medperf/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 242d9b734..06e2608a4 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -22,7 +22,7 @@ import medperf.commands.certificate.certificate as certificate import medperf.commands.storage as storage -# from medperf.utils import check_for_updates +from medperf.utils import check_for_updates from medperf.logging.utils import log_machine_details app = typer.Typer() @@ -111,6 +111,6 @@ def main( logging.info(f"Running MedPerf v{__version__} on {loglevel} logging level") logging.info(f"Executed command: {' '.join(sys.argv[1:])}") log_machine_details() - # check_for_updates() + check_for_updates() config.ui.print(f"MedPerf {__version__}") From 6d2e16ee85b96832cf534ca4488289eadbd5cce8 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Tue, 11 Jun 2024 15:13:06 -0700 Subject: [PATCH 076/242] Draft example of NNUNet Integration with OpenFL and MedPerf --- cli/medperf/commands/dataset/train.py | 2 +- examples/fl_post/fl/.gitignore | 5 + examples/fl_post/fl/.test.sh.swp | Bin 0 -> 12288 bytes examples/fl_post/fl/README.md | 6 + examples/fl_post/fl/build.sh | 1 + examples/fl_post/fl/clean.sh | 9 + examples/fl_post/fl/csr.conf | 23 + examples/fl_post/fl/mlcube/mlcube.yaml | 48 ++ .../fl/mlcube/workspace/training_config.yaml | 59 +++ examples/fl_post/fl/project/Dockerfile | 19 + examples/fl_post/fl/project/README.md | 38 ++ examples/fl_post/fl/project/aggregator.py | 60 +++ 
examples/fl_post/fl/project/collaborator.py | 32 ++ examples/fl_post/fl/project/hooks.py | 109 +++++ examples/fl_post/fl/project/mlcube.py | 128 ++++++ .../fl_post/fl/project/nnunet_data_setup.py | 417 ++++++++++++++++++ .../fl_post/fl/project/nnunet_model_setup.py | 181 ++++++++ examples/fl_post/fl/project/nnunet_setup.py | 228 ++++++++++ examples/fl_post/fl/project/plan.py | 16 + examples/fl_post/fl/project/requirements.txt | 3 + examples/fl_post/fl/project/src/__init__.py | 3 + .../src/launch_collaborator_with_env_vars.sh | 1 + .../fl/project/src/nnunet_dummy_dataloader.py | 36 ++ examples/fl_post/fl/project/src/nnunet_v1.py | 278 ++++++++++++ .../fl_post/fl/project/src/runner_nnunetv1.py | 233 ++++++++++ .../fl_post/fl/project/src/runner_pt_chkpt.py | 321 ++++++++++++++ examples/fl_post/fl/project/utils.py | 132 ++++++ examples/fl_post/fl/setup_clean.sh | 5 + examples/fl_post/fl/setup_test.sh | 124 ++++++ examples/fl_post/fl/sync.sh | 6 + examples/fl_post/fl/test_agg.sh | 27 ++ examples/fl_post/fl/test_col1.sh | 25 ++ examples/fl_post/fl/test_col2.sh | 25 ++ 33 files changed, 2599 insertions(+), 1 deletion(-) create mode 100644 examples/fl_post/fl/.gitignore create mode 100644 examples/fl_post/fl/.test.sh.swp create mode 100644 examples/fl_post/fl/README.md create mode 100755 examples/fl_post/fl/build.sh create mode 100755 examples/fl_post/fl/clean.sh create mode 100644 examples/fl_post/fl/csr.conf create mode 100644 examples/fl_post/fl/mlcube/mlcube.yaml create mode 100644 examples/fl_post/fl/mlcube/workspace/training_config.yaml create mode 100644 examples/fl_post/fl/project/Dockerfile create mode 100644 examples/fl_post/fl/project/README.md create mode 100644 examples/fl_post/fl/project/aggregator.py create mode 100644 examples/fl_post/fl/project/collaborator.py create mode 100644 examples/fl_post/fl/project/hooks.py create mode 100644 examples/fl_post/fl/project/mlcube.py create mode 100644 examples/fl_post/fl/project/nnunet_data_setup.py create mode 100644 
examples/fl_post/fl/project/nnunet_model_setup.py create mode 100644 examples/fl_post/fl/project/nnunet_setup.py create mode 100644 examples/fl_post/fl/project/plan.py create mode 100644 examples/fl_post/fl/project/requirements.txt create mode 100644 examples/fl_post/fl/project/src/__init__.py create mode 100644 examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh create mode 100644 examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py create mode 100644 examples/fl_post/fl/project/src/nnunet_v1.py create mode 100644 examples/fl_post/fl/project/src/runner_nnunetv1.py create mode 100644 examples/fl_post/fl/project/src/runner_pt_chkpt.py create mode 100644 examples/fl_post/fl/project/utils.py create mode 100644 examples/fl_post/fl/setup_clean.sh create mode 100644 examples/fl_post/fl/setup_test.sh create mode 100755 examples/fl_post/fl/sync.sh create mode 100755 examples/fl_post/fl/test_agg.sh create mode 100755 examples/fl_post/fl/test_col1.sh create mode 100755 examples/fl_post/fl/test_col2.sh diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index bb67c09af..f22bb86c4 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -73,7 +73,7 @@ def prepare_plan(self): def prepare_pki_assets(self): ca = CA.from_experiment(self.training_exp_id) - # trust(ca) + trust(ca) self.dataset_pki_assets = get_pki_assets_path(self.user_email, ca.name) self.ca = ca diff --git a/examples/fl_post/fl/.gitignore b/examples/fl_post/fl/.gitignore new file mode 100644 index 000000000..70b318917 --- /dev/null +++ b/examples/fl_post/fl/.gitignore @@ -0,0 +1,5 @@ +mlcube_* +ca +quick* +mlcube/workspace/additional_files/init_nnunet/* +mlcube/workspace/additional_files/init_weights/* \ No newline at end of file diff --git a/examples/fl_post/fl/.test.sh.swp b/examples/fl_post/fl/.test.sh.swp new file mode 100644 index 
0000000000000000000000000000000000000000..327381c6955512a7f275e4252a493383572270b1 GIT binary patch literal 12288 zcmeI2L2uJA6vsV;5J(IVh`VW#I3;P)O-w_R7&}ZzgQ1EcaSCqgCS^%%+3rR!Ai)W7 z=D?jVuulL&;xiyV3m2Z-q{{|$K~oQ)7wIW=-n0Mlug}k_>Tc`F{Rec#^$A>82-!Sv z`#bK(bMHE5Nv0E?rd(OqtV>Jf#B$ggk1-?gWjUZ14~&=bp5&_UBuum1<1bhyQ?9*! z8cLxJ+&#m!adk3dM06nmBv6aM(Da>^#hrHGw`_c~af4pFx>Ktw1|b0?fCP{L5{EBdvJ@i3(8!?)QxWURz{{jbdm`H9imIV~ZJzJZjqR%&@$ZD<;Uu2OMB3=0B6VaY!@wA#63%E zwe$g~RVC)2GGSlQ@#mv6{U$3?!F;5Otf&aJ=OR=QSelD~XEU>+vP_+U87HJsBPj;C zph_&+yKz!nP&3Z2 z6Suy(S%wk&4r`6#FtVKql#pWG6%Dss1F$~qx;9MjmL-M%Zeg~o3$tMrg8B-!temJ2 z{G#C3SFm;MM1>>8r9psLO#v3i<@nk#D4g&>@t7H*sPl*}w}RX42B96OY<5X%`1vU5 z6fp!>D$rrAoCetM;4tE^`v#}zTd9piLBUnCUYlflv)q+!U^=xUb0@VjvI0BiIakcU f_6Hjb>>y>E+BRi*k46HvLRh~>P|l)cjGz1jYQDe5 literal 0 HcmV?d00001 diff --git a/examples/fl_post/fl/README.md b/examples/fl_post/fl/README.md new file mode 100644 index 000000000..918f483e3 --- /dev/null +++ b/examples/fl_post/fl/README.md @@ -0,0 +1,6 @@ +# How to run tests + +- Run `setup_test.sh` just once to create certs and download required data. +- Run `test.sh` to start the aggregator and three collaborators. +- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. 
diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh new file mode 100755 index 000000000..d56304274 --- /dev/null +++ b/examples/fl_post/fl/build.sh @@ -0,0 +1 @@ +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/clean.sh b/examples/fl_post/fl/clean.sh new file mode 100755 index 000000000..ce7879606 --- /dev/null +++ b/examples/fl_post/fl/clean.sh @@ -0,0 +1,9 @@ +rm -rf mlcube_agg/workspace/final_weights +rm -rf mlcube_agg/workspace/logs +rm -rf mlcube_col1/workspace/logs +rm -rf mlcube_col2/workspace/logs +rm -rf mlcube_col3/workspace/logs +rm -rf mlcube_agg/workspace/plan.yaml +rm -rf mlcube_col1/workspace/plan.yaml +rm -rf mlcube_col2/workspace/plan.yaml +rm -rf mlcube_col3/workspace/plan.yaml diff --git a/examples/fl_post/fl/csr.conf b/examples/fl_post/fl/csr.conf new file mode 100644 index 000000000..5ac85ae39 --- /dev/null +++ b/examples/fl_post/fl/csr.conf @@ -0,0 +1,23 @@ +[ req ] +default_bits = 3072 +prompt = no +default_md = sha384 +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[ req_distinguished_name ] +commonName = spr-gpu01.jf.intel.com + +[ req_ext ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names + +[ alt_names ] +DNS.1 = spr-gpu01.jf.intel.com + +[ v3_client ] +extendedKeyUsage = critical,clientAuth + +[ v3_server ] +extendedKeyUsage = critical,serverAuth diff --git a/examples/fl_post/fl/mlcube/mlcube.yaml b/examples/fl_post/fl/mlcube/mlcube.yaml new file mode 100644 index 000000000..39ecc21a9 --- /dev/null +++ b/examples/fl_post/fl/mlcube/mlcube.yaml @@ -0,0 +1,48 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + gpu_args: "--shm-size 12g" + # Image name + image: msheller/mlcube_testing:nnunet_fl_test + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. 
+ build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. + build_file: "Dockerfile" + +tasks: + train: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + init_nnunet_directory: additional_files/init_nnunet/ + outputs: + output_logs: logs/ + start_aggregator: + parameters: + inputs: + input_weights: additional_files/init_weights + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + collaborators: cols.yaml + outputs: + output_logs: logs/ + output_weights: final_weights/ + report_path: { type: "file", default: "report/report.yaml" } + generate_plan: + parameters: + inputs: + training_config_path: training_config.yaml + aggregator_config_path: aggregator_config.yaml + outputs: + plan_path: { type: "file", default: "plan/plan.yaml" } diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml new file mode 100644 index 000000000..642f67575 --- /dev/null +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -0,0 +1,59 @@ +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/fl_post_two_init.pbuf + best_state_path : save/fl_post_two_best.pbuf + last_state_path : save/fl_post_two_last.pbuf + rounds_to_train : 10 + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : CONTINUE_LOCAL + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.nnunet_dummy_dataloader.NNUNetDummyDataLoader + settings : + p_train : 0.8 + +# TODO: make checkpoint-only truly generic and create the task runner within src +task_runner : + defaults : plan/defaults/task_runner.yaml + template : 
src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner + settings : + device : cuda + gpu_num_string : '0' + nnunet_task : Task537_FLPost + +network : + defaults : plan/defaults/network.yaml + settings: {} + +assigner : + defaults : plan/defaults/assigner.yaml + template : openfl.component.RandomGroupedAssigner + settings : + task_groups : + - name : train_and_validate + percentage : 1.0 + tasks : + # - aggregated_model_validation + - train + # - locally_tuned_model_validation + +tasks : + defaults : plan/defaults/tasks_torch.yaml + train: + function : train + kwargs : + metrics : + - train_loss + - val_eval + epochs : 1 + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile new file mode 100644 index 000000000..984c4cb3b --- /dev/null +++ b/examples/fl_post/fl/project/Dockerfile @@ -0,0 +1,19 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" + + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +# RUN pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 +RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube project folder +COPY . 
/mlcube_project + +ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl_post/fl/project/README.md b/examples/fl_post/fl/project/README.md new file mode 100644 index 000000000..1e348651b --- /dev/null +++ b/examples/fl_post/fl/project/README.md @@ -0,0 +1,38 @@ +# How to configure container build for your application + +- List your pip requirements in `requirements.txt` +- List your software requirements in `Dockerfile` +- Modify the functions in `hooks.py` as needed. (Explanation TBD) + +# How to configure container for custom FL software + +- Change the base Docker image as needed. +- modify `aggregator.py` and `collaborator.py` as needed. Follow the implemented schema steps. + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. 
+ +# How to customize + +TBD diff --git a/examples/fl_post/fl/project/aggregator.py b/examples/fl_post/fl/project/aggregator.py new file mode 100644 index 000000000..c0bbeafa1 --- /dev/null +++ b/examples/fl_post/fl/project/aggregator.py @@ -0,0 +1,60 @@ +from utils import ( + get_aggregator_fqdn, + prepare_node_cert, + prepare_ca_cert, + prepare_plan, + prepare_cols_list, + prepare_init_weights, + create_workspace, + get_weights_path, +) + +import os +import shutil +from subprocess import check_call +from distutils.dir_util import copy_tree + + +def start_aggregator( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, +): + + workspace_folder = os.path.join(output_logs, "workspace") + create_workspace(workspace_folder) + prepare_plan(plan_path, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + check_call(["fx", "aggregator", "start"], cwd=workspace_folder) + + # TODO: check how to copy logs during runtime. + # perhaps investigate overriding plan entries? + + # NOTE: logs and weights are copied, even if target folders are not empty + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) + + # NOTE: conversion fails since openfl needs sample data... 
+ # weights_paths = get_weights_path(fl_workspace) + # out_best = os.path.join(output_weights, "best") + # out_last = os.path.join(output_weights, "last") + # check_call( + # ["fx", "model", "save", "-i", weights_paths["best"], "-o", out_best], + # cwd=workspace_folder, + # ) + copy_tree(os.path.join(workspace_folder, "save"), output_weights) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) + + with open(report_path, 'w') as f: + f.write("IsDone: 1") diff --git a/examples/fl_post/fl/project/collaborator.py b/examples/fl_post/fl/project/collaborator.py new file mode 100644 index 000000000..38c5048b6 --- /dev/null +++ b/examples/fl_post/fl/project/collaborator.py @@ -0,0 +1,32 @@ +from utils import ( + get_collaborator_cn, + prepare_node_cert, + prepare_ca_cert, + prepare_plan, + create_workspace, +) +import os +import shutil +from subprocess import check_call + + +def start_collaborator( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, +): + workspace_folder = os.path.join(output_logs, "workspace") + create_workspace(workspace_folder) + prepare_plan(plan_path, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + # set log files + check_call(["fx", "collaborator", "start", "-n", cn], cwd=workspace_folder) + + # Cleanup + shutil.rmtree(workspace_folder, ignore_errors=True) diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py new file mode 100644 index 000000000..dca2c0231 --- /dev/null +++ b/examples/fl_post/fl/project/hooks.py @@ -0,0 +1,109 @@ +import os +import shutil +import pandas as pd +from utils import get_collaborator_cn + + +def __modify_df(df): + # gandlf convention: labels columns could be "target", "label", "mask" + # subject id column is subjectid. data columns are Channel_0. + # Others could be scalars. 
# TODO + labels_columns = ["target", "label", "mask"] + data_columns = ["channel_0"] + subject_id_column = "subjectid" + for column in df.columns: + if column.lower() == subject_id_column: + continue + if column.lower() in labels_columns: + prepend_str = "labels/" + elif column.lower() in data_columns: + prepend_str = "data/" + else: + continue + + df[column] = prepend_str + df[column].astype(str) + + +def collaborator_pre_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + init_nnunet_directory, +): + # runtime env vars should be set as early as possible + tmpfolder = os.path.join(output_logs, ".tmp") + os.environ["TMPDIR"] = tmpfolder + os.makedirs(tmpfolder, exist_ok=True) + os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") + os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") + os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") + import nnunet_setup + + cn = get_collaborator_cn() + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + + os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) + os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) + + nnunet_setup.main([workspace_folder], + 537, # FIXME: does this need to be set in any particular way? 
+ f'{init_nnunet_directory}/model_initial_checkpoint.model', + f'{init_nnunet_directory}/model_initial_checkpoint.model.pkl', + 'FLPost', + .8, + 'by_subject_time_pair', + '3d_fullres', + 'nnUNetTrainerV2', + '0', + plans_identifier=None, + num_institutions=1, + cuda_device='0', + verbose=False) + + data_config = f"{cn},Task537_FLPost" + plan_folder = os.path.join(workspace_folder, "plan") + os.makedirs(plan_folder, exist_ok=True) + data_config_path = os.path.join(plan_folder, "data.yaml") + with open(data_config_path, "w") as f: + f.write(data_config) + shutil.copytree('/mlcube_project/src', os.path.join(workspace_folder, 'src')) + +def collaborator_post_training_hook( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, +): + pass + + +def aggregator_pre_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, +): + pass + + +def aggregator_post_training_hook( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, +): + pass diff --git a/examples/fl_post/fl/project/mlcube.py b/examples/fl_post/fl/project/mlcube.py new file mode 100644 index 000000000..0fe02af13 --- /dev/null +++ b/examples/fl_post/fl/project/mlcube.py @@ -0,0 +1,128 @@ +"""MLCube handler file""" + +import os +import shutil +import typer +from collaborator import start_collaborator +from aggregator import start_aggregator +from plan import generate_plan +from hooks import ( + aggregator_pre_training_hook, + aggregator_post_training_hook, + collaborator_pre_training_hook, + collaborator_post_training_hook, +) + +app = typer.Typer() + + +def _setup(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + os.makedirs(tmp_folder, exist_ok=True) + # TODO: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmp_folder + + +def _teardown(output_logs): + tmp_folder = 
os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +@app.command("train") +def train( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + output_logs: str = typer.Option(..., "--output_logs"), + init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), +): + _setup(output_logs) + collaborator_pre_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + init_nnunet_directory=init_nnunet_directory, + ) + start_collaborator( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + ) + collaborator_post_training_hook( + data_path=data_path, + labels_path=labels_path, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + plan_path=plan_path, + output_logs=output_logs, + ) + _teardown(output_logs) + + +@app.command("start_aggregator") +def start_aggregator_( + input_weights: str = typer.Option(..., "--input_weights"), + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + output_logs: str = typer.Option(..., "--output_logs"), + output_weights: str = typer.Option(..., "--output_weights"), + plan_path: str = typer.Option(..., "--plan_path"), + collaborators: str = typer.Option(..., "--collaborators"), + report_path: str = typer.Option(..., "--report_path"), +): + _setup(output_logs) + aggregator_pre_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + 
output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + ) + start_aggregator( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + ) + aggregator_post_training_hook( + input_weights=input_weights, + node_cert_folder=node_cert_folder, + ca_cert_folder=ca_cert_folder, + output_logs=output_logs, + output_weights=output_weights, + plan_path=plan_path, + collaborators=collaborators, + report_path=report_path, + ) + _teardown(output_logs) + + +@app.command("generate_plan") +def generate_plan_( + training_config_path: str = typer.Option(..., "--training_config_path"), + aggregator_config_path: str = typer.Option(..., "--aggregator_config_path"), + plan_path: str = typer.Option(..., "--plan_path"), +): + # no _setup here since there is no writable output mounted volume. + # later if need this we think of a solution. Currently the create_plam + # logic is assumed to not write within the container. 
+ generate_plan(training_config_path, aggregator_config_path, plan_path) + + +if __name__ == "__main__": + app() diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py new file mode 100644 index 000000000..c7647f448 --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -0,0 +1,417 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Brandon Edwards, Micah Sheller + +""" +import os +import subprocess +import pickle as pkl +import shutil +import numpy as np + +from collections import OrderedDict + +from nnunet.dataset_conversion.utils import generate_dataset_json + +from nnunet_model_setup import trim_data_and_setup_nnunet_models + + +num_to_modality = {'_0000': '_brain_t1n.nii.gz', + '_0001': '_brain_t2w.nii.gz', + '_0002': '_brain_t1c.nii.gz', + '_0003': '_brain_t2f.nii.gz'} + + +def get_subdirs(parent_directory): + subjects = os.listdir(parent_directory) + # print("before filter:", subjects) + subjects = [p for p in subjects if os.path.isdir(os.path.join(parent_directory, p)) and not p.startswith(".")] + # print("after filter:", subjects) + return subjects + + +def subject_time_to_mask_path(pardir, subject, timestamp): + mask_fname = f'{subject}_{timestamp}_tumorMask_model_0.nii.gz' + return os.path.join(pardir, 'labels', '.tumor_segmentation_backup', subject, timestamp,'TumorMasksForQC', mask_fname) + + +def create_task_folders(first_three_digit_task_num, num_institutions, task_name): + """ + Creates task folders for all simulated instiutions in the federation + """ + nnunet_dst_pardirs = [] + nnunet_images_train_pardirs = [] + nnunet_labels_train_pardirs = [] + + task_nums = range(first_three_digit_task_num, first_three_digit_task_num + num_institutions) + tasks = [f'Task{str(num)}_{task_name}' for num in task_nums] + for task in tasks: + + # The NNUnet data path is obtained from an environmental variable + nnunet_dst_pardir = 
os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data', f'{task}') + + nnunet_images_train_pardir = os.path.join(nnunet_dst_pardir, 'imagesTr') + nnunet_labels_train_pardir = os.path.join(nnunet_dst_pardir, 'labelsTr') + + if os.path.exists(nnunet_images_train_pardir) and os.path.exists(nnunet_labels_train_pardir): + raise ValueError(f"Train images pardirs: {nnunet_images_train_pardir} and {nnunet_labels_train_pardir} both already exist. Please move them both and rerun to prevent overwriting.") + elif os.path.exists(nnunet_images_train_pardir): + raise ValueError(f"Train images pardir: {nnunet_images_train_pardir} already exists, please move and run again to prevent overwriting.") + elif os.path.exists(nnunet_labels_train_pardir): + raise ValueError(f"Train labels pardir: {nnunet_labels_train_pardir} already exists, please move and run again to prevent overwriting.") + + os.makedirs(nnunet_images_train_pardir, exist_ok=False) + os.makedirs(nnunet_labels_train_pardir, exist_ok=False) + + nnunet_dst_pardirs.append(nnunet_dst_pardir) + nnunet_images_train_pardirs.append(nnunet_images_train_pardir) + nnunet_labels_train_pardirs.append(nnunet_labels_train_pardir) + + return task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs + + +def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_labels_dirpath, nnunet_images_train_pardir, nnunet_labels_train_pardir, timestamp_selection): + postopp_subject_dirpath = os.path.join(postopp_data_dirpath, postopp_subject_dir) + all_timestamps = get_subdirs(postopp_subject_dirpath) + if timestamp_selection == 'latest': + timestamps = all_timestamps[-1:] + elif timestamp_selection == 'earliest': + timestamps = all_timestamps[0:1] + elif timestamp_selection == 'all': + timestamps = all_timestamps + else: + raise ValueError(f"timestamp_selection currently only supports 'latest', 'earliest', and 'all', but you have requested: '{timestamp_selection}'") + + for timestamp 
in timestamps: + postopp_subject_timestamp_dirpath = os.path.join(postopp_subject_dirpath, timestamp) + postopp_subject_timestamp_label_dirpath = os.path.join(postopp_labels_dirpath, postopp_subject_dir, timestamp) + if not os.path.exists(postopp_subject_timestamp_label_dirpath): + raise ValueError(f"Subject label file for data at: {postopp_subject_timestamp_dirpath} was not found in the expected location: {postopp_subject_timestamp_label_dirpath}") + + timed_subject = postopp_subject_dir + '_' + timestamp + + # Symlink label first + label_src_path = os.path.join(postopp_subject_timestamp_label_dirpath, timed_subject + '_final_seg.nii.gz') + label_dst_path = os.path.join(nnunet_labels_train_pardir, timed_subject + '.nii.gz') + os.symlink(src=label_src_path, dst=label_dst_path) + + # Symlink images + for num in num_to_modality: + src_path = os.path.join(postopp_subject_timestamp_dirpath, timed_subject + num_to_modality[num]) + dst_path = os.path.join(nnunet_images_train_pardir,timed_subject + num + '.nii.gz') + os.symlink(src=src_path, dst=dst_path) + + return timestamps + + +def doublecheck_postopp_pardir(postopp_pardir, verbose=False): + if verbose: + print(f"Checking postopp_pardir: {postopp_pardir}") + postopp_subdirs = get_subdirs(postopp_pardir) + if 'data' not in postopp_subdirs: + raise ValueError(f"'data' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") + if 'labels' not in postopp_subdirs: + raise ValueError(f"'labels' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") + + +def split_by_subject(subject_to_timestamps, percent_train, verbose=False): + """ + NOTE: An attempt is made to put percent_train of the total subjects into train (as opposed to val) regardless of how many timestamps there are for each subject. + No subject is allowed to have samples in both train and val. 
+ """ + + subjects = list(subject_to_timestamps.keys()) + np.random.shuffle(subjects) + + train_cutoff = int(len(subjects) * percent_train) + + train_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in subjects[:train_cutoff] } + val_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in subjects[train_cutoff:]} + + return train_subject_to_timestamps, val_subject_to_timestamps + + +def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, verbose=False): + """ + NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios. + No subject is allowed to have samples in both train and val. + """ + def percent_train_for_split(train_subjects, grand_total): + sub_total = 0 + for subject in train_subjects: + sub_total += subject_counts[subject] + return sub_total/grand_total + + def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): + subjects = list(subject_counts.keys()) + np.random.shuffle(subjects) + for idx in range(2,len(subjects)+1): + train_subjects = subjects[:idx-1] + val_subjects = subjects[idx-1:] + percent_train_estimate = percent_train_for_split(train_subjects=train_subjects, grand_total=grand_total) + if percent_train_estimate >= percent_train: + """ + if verbose: + print(f"SPLIT COMPUTE - Found one split with percent_train of: {percent_train_estimate}") + """ + break + return train_subjects, val_subjects, percent_train_estimate + # above should return by end of loop as percent_train_estimate should be strictly increasing with final value 1.0 + + + subject_counts = {subject: len(subject_to_timestamps[subject]) for subject in subject_to_timestamps} + subjects_copy = list(subject_counts.keys()).copy() + grand_total = 0 + for subject in subject_counts: + grand_total += subject_counts[subject] + + # create a valid split of counts for comparison + 
best_train_subjects = subjects_copy[:1] + best_val_subjects = subjects_copy[1:] + best_percent_train = percent_train_for_split(train_subjects=best_train_subjects, grand_total=grand_total) + + # random shuffle times in order to find the closest we can get to honoring the percent_train requirement (train and val both need to be non-empty) + for _ in range(random_tries): + train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, verbose=verbose) + if abs(percent_train_estimate - percent_train) < abs(best_percent_train - percent_train): + best_train_subjects = train_subjects + best_val_subjects = val_subjects + best_percent_train = percent_train_estimate + if verbose: + print(f"\n#########\n Split was performed by timed subject and an error of {abs(best_percent_train - percent_train)} was acheived in the percent train target.") + train_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in best_train_subjects} + val_subject_to_timestamps = {subject: subject_to_timestamps[subject] for subject in best_val_subjects} + return train_subject_to_timestamps, val_subject_to_timestamps + + +def write_splits_file(nnunet_dst_pardir, subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False): + # double check we are in the right folder to modify the splits file + splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname) + + # now split + if split_logic == 'by_subject': + train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + elif split_logic == 'by_subject_time_pair': + train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + else: + raise 
ValueError(f"Split logic of 'by_subject' and 'by_subject_time_pair' are the only ones supported, whereas a split_logic value of {split_logic} was provided.") + + # Now construct the list of subjects + train_subjects_list = [] + val_subjects_list = [] + for subject in train_subject_to_timestamps: + for timestamp in train_subject_to_timestamps[subject]: + train_subjects_list.append(subject + '_' + timestamp) + for subject in val_subject_to_timestamps: + for timestamp in val_subject_to_timestamps[subject]: + val_subjects_list.append(subject + '_' + timestamp) + + # Now write the splits file (note None is put into the folds that we don't use as a safety measure so that no unintended folds are used) + new_folds = [None, None, None, None, None] + new_folds[int(fold)] = OrderedDict({'train': np.array(train_subjects_list), 'val': np.array(val_subjects_list)}) + with open(splits_fpath, 'wb') as f: + pkl.dump(new_folds, f) + + +def setup_nnunet_data(postopp_pardirs, + first_three_digit_task_num, + task_name, + percent_train, + split_logic, + fold, + timestamp_selection, + num_institutions, + network, + network_trainer, + plans_identifier, + init_model_path, + init_model_info_path, + cuda_device, + verbose=False): + """ + Generates symlinks to be used for NNUnet training, assuming we already have a + dataset on file coming from MLCommons RANO experiment data prep. + + Also creates the json file for the data, as well as runs nnunet preprocessing. + + should be run using a virtual environment that has nnunet version 1 installed. + + args: + postopp_pardirs(list of str) : Parent directories for postopp data. The length of the list should either be + equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, + the samples within that single directory will be used to create num_insititutions shards. + If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. 
+ Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + ├── data + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ ├── AAAC_0_2008.03.30_brain_t1c.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t1n.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t2f.nii.gz + │ │ │ └── AAAC_0_2008.03.30_brain_t2w.nii.gz + │ │ └── 2008.12.17 + │ │ ├── AAAC_0_2008.12.17_brain_t1c.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t1n.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t2f.nii.gz + │ │ └── AAAC_0_2008.12.17_brain_t2w.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1c.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1n.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t2f.nii.gz + │ │ │ └── AAAC_1_2008.03.30_duplicate_brain_t2w.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1c.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1n.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t2f.nii.gz + │ │ └── AAAC_1_2008.12.17_duplicate_brain_t2w.nii.gz + │ ├── AAAC_extra + │ │ └── 2008.12.10 + │ │ ├── AAAC_extra_2008.12.10_brain_t1c.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t1n.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t2f.nii.gz + │ │ └── AAAC_extra_2008.12.10_brain_t2w.nii.gz + │ ├── data.csv + │ └── splits.csv + ├── labels + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ └── AAAC_0_2008.03.30_final_seg.nii.gz + │ │ └── 2008.12.17 + │ │ └── AAAC_0_2008.12.17_final_seg.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ └── AAAC_1_2008.03.30_duplicate_final_seg.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ └── AAAC_1_2008.12.17_duplicate_final_seg.nii.gz + │ └── AAAC_extra + │ └── 2008.12.10 + │ └── AAAC_extra_2008.12.10_final_seg.nii.gz + └── report.yaml + + first_three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. 
+ task_name(str) : Any string task name. + timestamp_selection(str) : Indicates how to determine the timestamp to pick + for each subject ID at the source: 'latest', 'earliest', and 'all' are the only ones supported so far + percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since + all timestamps associated with the same subject need to land exclusively in either train or val. + split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair'. + fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' + num_institutions(int) : Number of simulated institutions to shard the data into. + verbose(bool) : Debugging output if True. + + Returns: + task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs + """ + + task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs = \ + create_task_folders(first_three_digit_task_num=first_three_digit_task_num, + num_institutions=num_institutions, + task_name=task_name) + + if len(postopp_pardirs) == 1: + postopp_pardir = postopp_pardirs[0] + doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) + postopp_data_dirpaths = num_institutions * [os.path.join(postopp_pardir, 'data')] + postopp_labels_dirpaths = num_institutions * [os.path.join(postopp_pardir, 'labels')] + + all_subjects = get_subdirs(postopp_data_dirpaths[0]) + subject_shards = [all_subjects[start::num_institutions] for start in range(num_institutions)] + elif len(postopp_pardirs) != num_institutions: + raise ValueError(f"The length of postopp_pardirs must be equal to the number of insitutions needed for the federation, or can be length one and an automated split is peroformed.") + else: + subject_shards = [] + postopp_data_dirpaths = [] + postopp_labels_dirpaths = [] + for postopp_pardir in postopp_pardirs: + 
doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) + postopp_data_dirpath = os.path.join(postopp_pardir, 'data') + postopp_labels_dirpath = os.path.join(postopp_pardir, 'labels') + postopp_data_dirpaths.append(postopp_data_dirpath) + postopp_labels_dirpaths.append(postopp_labels_dirpath) + subject_shards.append(get_subdirs(postopp_labels_dirpath)) + + # Track the subjects and timestamps for each shard + shard_subject_to_timestamps = [] + + for shard_idx, (postopp_subject_dirs, task_num, task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir, postopp_data_dirpath, postopp_labels_dirpath) in \ + enumerate(zip(subject_shards, task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs, postopp_data_dirpaths, postopp_labels_dirpaths)): + print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA FOR COLLABORATOR {shard_idx} #########\n") + subject_to_timestamps = {} + for postopp_subject_dir in postopp_subject_dirs: + subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, + postopp_data_dirpath=postopp_data_dirpath, + postopp_labels_dirpath=postopp_labels_dirpath, + nnunet_images_train_pardir=nnunet_images_train_pardir, + nnunet_labels_train_pardir=nnunet_labels_train_pardir, + timestamp_selection=timestamp_selection) + shard_subject_to_timestamps.append(subject_to_timestamps) + + # Generate json file for the dataset + print(f"\n######### GENERATING DATA JSON FILE FOR COLLABORATOR {shard_idx} #########\n") + json_path = os.path.join(nnunet_dst_pardir, 'dataset.json') + labels = {0: 'Background', 1: 'Necrosis', 2: 'Edema', 3: 'Enhancing Tumor', 4: 'Cavity'} + # labels = {0: 'Background', 1: 'Necrosis', 2: 'Edema'} + # print(f"{nnunet_images_train_pardir}") + # print(list(os.listdir(nnunet_images_train_pardir))) + + + # from typing import List, Union + # def subfiles(folder: str, join: bool = True, prefix: Union[List[str], str] = None, + # suffix: 
Union[List[str], str] = None, sort: bool = True) -> List[str]: + # if join: + # l = os.path.join + # else: + # l = lambda x, y: y + + # if prefix is not None and isinstance(prefix, str): + # prefix = [prefix] + # if suffix is not None and isinstance(suffix, str): + # suffix = [suffix] + # print([ i for i in os.listdir(folder)]) + # print([ i for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) ]) + # res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) + # and (prefix is None or any([i.startswith(j) for j in prefix])) + # and (suffix is None or any([i.endswith(j) for j in suffix]))] + + # if sort: + # res.sort() + # return res + + + # uniques = np.unique([i[:-12] for i in subfiles(nnunet_images_train_pardir, suffix='.nii.gz', join=False)]) + # print("UNIQUES::::\n",uniques) + generate_dataset_json(output_file=json_path, + imagesTr_dir=nnunet_images_train_pardir, + imagesTs_dir=None, + modalities=tuple(num_to_modality.keys()), + labels=labels, + dataset_name='RANO Postopp') + + # Now call the os process to preprocess the data + print(f"\n######### OS CALL TO PREPROCESS DATA FOR COLLABORATOR {shard_idx} #########\n") + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{task_num}", "--verify_dataset_integrity"]) + + # trim 2d data if not working with 2d model, and distribute common model architecture across simulated collaborators + trim_data_and_setup_nnunet_models(tasks=tasks, + network=network, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold, + init_model_path=init_model_path, + init_model_info_path=init_model_info_path, + cuda_device=cuda_device) + + + for task, subject_to_timestamps in zip(tasks, shard_subject_to_timestamps): + # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val + write_splits_file(nnunet_dst_pardir=nnunet_dst_pardir, + subject_to_timestamps=subject_to_timestamps, + 
percent_train=percent_train, + split_logic=split_logic, + fold=fold, + task=task, + verbose=verbose) + diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py new file mode 100644 index 000000000..30d550bdb --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -0,0 +1,181 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Brandon Edwards, Micah Sheller + +""" + +import os +import pickle as pkl +import shutil + +# from nnunet_v1 import train_nnunet + + +def train_on_task(task, network, network_trainer, fold, cuda_device, continue_training=False, current_epoch=0): + os.environ['CUDA_VISIBLE_DEVICES']=cuda_device + print(f"###########\nStarting training for task: {task}\n") + train_nnunet(epochs=1, + current_epoch = current_epoch, + network = network, + task=task, + network_trainer = network_trainer, + fold=fold, + continue_training=continue_training) + + +def model_folder(network, task, network_trainer, plans_identifier, fold, results_folder=os.environ['RESULTS_FOLDER']): + return os.path.join(results_folder, 'nnUNet',network, task, network_trainer + '__' + plans_identifier, f'fold_{fold}') + + +def model_paths_from_folder(model_folder): + return {'initial_model_path': os.path.join(model_folder, 'model_initial_checkpoint.model'), + 'final_model_path': os.path.join(model_folder, 'model_final_checkpoint.model'), + 'initial_model_info_path': os.path.join(model_folder, 'model_initial_checkpoint.model.pkl'), + 'final_model_info_path': os.path.join(model_folder, 'model_final_checkpoint.model.pkl')} + + +def plan_path(network, task, plans_identifier): + preprocessed_path = os.environ['nnUNet_preprocessed'] + plan_dirpath = os.path.join(preprocessed_path, task) + plan_path_2d = os.path.join(plan_dirpath, plans_identifier + "_plans_2D.pkl") + plan_path_3d = os.path.join(plan_dirpath, plans_identifier + "_plans_3D.pkl") + + if network =='2d': 
+ plan_path = plan_path_2d + else: + plan_path = plan_path_3d + + return plan_path + +def delete_2d_data(network, task, plans_identifier): + if network == '2d': + raise ValueError(f"2D data should not be deleted when performing 2d training.") + else: + preprocessed_path = os.environ['nnUNet_preprocessed'] + plan_dirpath = os.path.join(preprocessed_path, task) + plan_path_2d = os.path.join(plan_dirpath, plans_identifier + "_plans_2D.pkl") + + if os.path.exists(plan_path_2d): + # load 2d plan to help construct 2D data directory + with open(plan_path_2d, 'rb') as _file: + plan_2d = pkl.load(_file) + data_dir_2d = os.path.join(plan_dirpath, plan_2d['data_identifier'] + '_stage' + str(list(plan_2d['plans_per_stage'].keys())[-1])) + print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") + shutil.rmtree(data_dir_2d) + + +def normalize_architecture(reference_plan_path, target_plan_path): + """ + Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path + + NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more + likely leed to an exception. 
+ + """ + + assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', + 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', + 'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier'] + copy_over_keys = ['plans_per_stage'] + nullify_keys = ['original_spacings', 'original_sizes'] + leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties'] + + + # check I got all keys here + assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage']) + + def get_pickle_obj(path): + with open(path, 'rb') as _file: + plan= pkl.load(_file) + return plan + + def write_pickled_obj(obj, path): + with open(path, 'wb') as _file: + pkl.dump(obj, _file) + + reference_plan = get_pickle_obj(path=reference_plan_path) + target_plan = get_pickle_obj(path=target_plan_path) + + for key in assert_same_keys: + if reference_plan[key] != target_plan[key]: + raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}") + for key in copy_over_keys: + target_plan[key] = reference_plan[key] + for key in nullify_keys: + target_plan[key] = None + # leave alone keys are left alone :) + + # write back to target plan + write_pickled_obj(obj=target_plan, path=target_plan_path) + + +def trim_data_and_setup_nnunet_models(tasks, network, network_trainer, 
plans_identifier, fold, init_model_path, init_model_info_path, cuda_device='0'): + + col_0_task = tasks[0] + # trim collaborator 0 data if appropriate + if network != '2d': + delete_2d_data(network=network, task=col_0_task, plans_identifier=plans_identifier) + # get the architecture info from the first collaborator 0 data setup results, and create its model folder (writing the initial model info into it) + col_0_plan_path = plan_path(network=network, task=col_0_task, plans_identifier=plans_identifier) + + col_0_model_folder = model_folder(network=network, + task=col_0_task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold) + os.makedirs(col_0_model_folder, exist_ok=False) + + col_0_model_files_dict = model_paths_from_folder(model_folder=model_folder(network=network, + task=col_0_task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold)) + if not init_model_path: + # train collaborator 0 for a single epoch to get an initial model + train_on_task(task=col_0_task, network=network, network_trainer=network_trainer, fold=fold, cuda_device=cuda_device) + # now copy the final model and info from the initial training run into the initial paths + shutil.copyfile(src=col_0_model_files_dict['final_model_path'],dst=col_0_model_files_dict['initial_model_path']) + shutil.copyfile(src=col_0_model_files_dict['final_model_info_path'],dst=col_0_model_files_dict['initial_model_info_path']) + else: + print(f"\n######### COPYING INITIAL MODEL FILES INTO COLLABORATOR 0 FOLDERS #########\n") + # Copy initial model and model info into col_0_model_folder + shutil.copyfile(src=init_model_path,dst=col_0_model_files_dict['initial_model_path']) + shutil.copyfile(src=init_model_info_path,dst=col_0_model_files_dict['initial_model_info_path']) + # now copy the initial model also into the final paths + shutil.copyfile(src=col_0_model_files_dict['initial_model_path'],dst=col_0_model_files_dict['final_model_path']) + 
shutil.copyfile(src=col_0_model_files_dict['initial_model_info_path'],dst=col_0_model_files_dict['final_model_info_path']) + + # now create the model folders for collaborators 1 and upward, populate them with the model files from 0, + # and replace their data directory plan files from the col_0 plan + for col_idx_minus_one, task in enumerate(tasks[1:]): + # trim collaborator data if appropriate + if network != '2d': + delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) + + print(f"\n######### COPYING MODEL INFO FROM COLLABORATOR 0 TO COLLABORATOR {col_idx_minus_one + 1} #########\n") + # replace data directory plan file with one from col_0 + target_plan_path = plan_path(network=network, task=task, plans_identifier=plans_identifier) + normalize_architecture(reference_plan_path=col_0_plan_path, target_plan_path=target_plan_path) + + # create model folder for this collaborator + this_col_model_folder = model_folder(network=network, + task=task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold) + os.makedirs(this_col_model_folder, exist_ok=False) + + # copy model, and model info files from col_0 to this collaborator's model folder + this_col_model_files_dict = model_paths_from_folder(model_folder=model_folder(network=network, + task=task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold)) + # Copy initial and final model and model info from col_0 into this_col_model_folder + shutil.copyfile(src=col_0_model_files_dict['initial_model_path'],dst=this_col_model_files_dict['initial_model_path']) + shutil.copyfile(src=col_0_model_files_dict['final_model_path'],dst=this_col_model_files_dict['final_model_path']) + shutil.copyfile(src=col_0_model_files_dict['initial_model_info_path'],dst=this_col_model_files_dict['initial_model_info_path']) + shutil.copyfile(src=col_0_model_files_dict['final_model_info_path'],dst=this_col_model_files_dict['final_model_info_path']) + \ No newline at 
end of file diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py new file mode 100644 index 000000000..cdc54d9e6 --- /dev/null +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -0,0 +1,228 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Brandon Edwards, Micah Sheller + +""" + +import argparse + +from nnunet.paths import default_plans_identifier + +from nnunet_data_setup import setup_nnunet_data +from nnunet_model_setup import trim_data_and_setup_nnunet_models + +def list_of_strings(arg): + return arg.split(',') + +def main(postopp_pardirs, + first_three_digit_task_num, + init_model_path, + init_model_info_path, + task_name, + percent_train, + split_logic, + network, + network_trainer, + fold, + plans_identifier=None, + timestamp_selection='all', + num_institutions=1, + cuda_device='0', + verbose=False): + """ + Generates symlinks to be used for NNUnet training, assuming we already have a + dataset on file coming from MLCommons RANO experiment data prep. + + Also creates the json file for the data, as well as runs nnunet preprocessing. + + should be run using a virtual environment that has nnunet version 1 installed. + + args: + postopp_pardirs(list of str) : Parent directories for postopp data. The length of the list should either be + equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, + the samples within that single directory will be used to create num_insititutions shards. + If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. 
+ Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + ├── data + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ ├── AAAC_0_2008.03.30_brain_t1c.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t1n.nii.gz + │ │ │ ├── AAAC_0_2008.03.30_brain_t2f.nii.gz + │ │ │ └── AAAC_0_2008.03.30_brain_t2w.nii.gz + │ │ └── 2008.12.17 + │ │ ├── AAAC_0_2008.12.17_brain_t1c.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t1n.nii.gz + │ │ ├── AAAC_0_2008.12.17_brain_t2f.nii.gz + │ │ └── AAAC_0_2008.12.17_brain_t2w.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1c.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t1n.nii.gz + │ │ │ ├── AAAC_1_2008.03.30_duplicate_brain_t2f.nii.gz + │ │ │ └── AAAC_1_2008.03.30_duplicate_brain_t2w.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1c.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t1n.nii.gz + │ │ ├── AAAC_1_2008.12.17_duplicate_brain_t2f.nii.gz + │ │ └── AAAC_1_2008.12.17_duplicate_brain_t2w.nii.gz + │ ├── AAAC_extra + │ │ └── 2008.12.10 + │ │ ├── AAAC_extra_2008.12.10_brain_t1c.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t1n.nii.gz + │ │ ├── AAAC_extra_2008.12.10_brain_t2f.nii.gz + │ │ └── AAAC_extra_2008.12.10_brain_t2w.nii.gz + │ ├── data.csv + │ └── splits.csv + ├── labels + │ ├── AAAC_0 + │ │ ├── 2008.03.30 + │ │ │ └── AAAC_0_2008.03.30_final_seg.nii.gz + │ │ └── 2008.12.17 + │ │ └── AAAC_0_2008.12.17_final_seg.nii.gz + │ ├── AAAC_1 + │ │ ├── 2008.03.30_duplicate + │ │ │ └── AAAC_1_2008.03.30_duplicate_final_seg.nii.gz + │ │ └── 2008.12.17_duplicate + │ │ └── AAAC_1_2008.12.17_duplicate_final_seg.nii.gz + │ └── AAAC_extra + │ └── 2008.12.10 + │ └── AAAC_extra_2008.12.10_final_seg.nii.gz + └── report.yaml + + first_three_digit_task_num(str) : Should start with '5'. If nnunet == N, all N task numbers starting with this number will be used. 
+ init_model_path (str) : path to initial (pretrained) model file [default None] - must be provided if init_model_info_path is. + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0 DATA.] + init_model_info_path(str) : path to initial (pretrained) model info pickle file [default None]- must be provided if init_model_path is. + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0 DATA.] + task_name(str) : Name of task that is part of the task name + percent_train(float) : The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite + all with None except the fold indicated below and put in our own split instead determined by a hard coded split logic default) + split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring) + network(str) : NNUnet network to be used + network_trainer(str) : NNUnet network trainer to be used + fold(str) : Fold to train on, can be a string indicating an int, or can be 'all' + plans_identifier(str) : Used in the plans file naming. + task_name(str) : Any string task name. + timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. + for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far + num_institutions(int) : Number of simulated institutions to shard the data into. + verbose(bool) : If True, print debugging information.
+ """ + if plans_identifier is None: + plans_identifier = default_plans_identifier + + + # some argument inspection + task_digit_length = len(str(first_three_digit_task_num)) + if task_digit_length != 3: + raise ValueError(f'The number of digits in {first_three_digit_task_num} should be 3, but it is: {task_digit_length} instead.') + if str(first_three_digit_task_num)[0] != '5': + raise ValueError(f"The three digit task number: {first_three_digit_task_num} should start with 5 to avoid NNUnet repository tasks, but it starts with {first_three_digit_task_num[0]}") + if init_model_path or init_model_info_path: + if not init_model_path or not init_model_info_path: + raise ValueError(f"If either init_model_path or init_model_info_path are provided, they both must be.") + if init_model_path: + if not init_model_path.endswith('.model'): + raise ValueError(f"Initial model file should end with, '.model'") + if not init_model_info_path.endswith('.model.pkl'): + raise ValueError(f"Initial model info file should end with, 'model.pkl'") + + + + # task_folder_info is a zipped lists indexed over tasks (collaborators) + # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) + setup_nnunet_data(postopp_pardirs=postopp_pardirs, + first_three_digit_task_num=first_three_digit_task_num, + task_name=task_name, + percent_train=percent_train, + split_logic=split_logic, + fold=fold, + timestamp_selection=timestamp_selection, + num_institutions=num_institutions, + network=network, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + init_model_path=init_model_path, + init_model_info_path=init_model_info_path, + cuda_device=cuda_device, + verbose=verbose) + +if __name__ == '__main__': + + argparser = argparse.ArgumentParser() + argparser.add_argument( + '--postopp_pardirs', + type=list_of_strings, + # nargs='+', + help="Parent directories to postopp data (all should have 'data' and 'labels' subdirectories). 
Length needs to equal num_institutions or be lengh 1.") + argparser.add_argument( + '--first_three_digit_task_num', + type=int, + help="Should start with '5'. If nnunet == N, all N task numbers starting with this number will be used.") + argparser.add_argument( + '--init_model_path', + type=str, + default=None, + help="Path to initial (pretrained) model file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0's DATA.].") + argparser.add_argument( + '--init_model_info_path', + type=str, + default=None, + help="Path to initial (pretrained) model info file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0's DATA.].") + argparser.add_argument( + '--task_name', + type=str, + help="Part of the NNUnet data task directory name. With 'first_three_digit_task_num being 'XXX', the directory name becomes: .../nnUNet_raw_data_base/nnUNet_raw_data/TaskXXX_.") + argparser.add_argument( + '--percent_train', + type=float, + default=0.8, + help="The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite) - see docstring in main") + argparser.add_argument( + '--split_logic', + type=str, + default='by_subject_time_pair', + help="Determines how the percent_train is computed. 
Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring)") + argparser.add_argument( + '--network', + type=str, + default='3d_fullres', + help="NNUnet network to be used.") + argparser.add_argument( + '--network_trainer', + type=str, + default='nnUNetTrainerV2', + help="NNUnet network trainer to be used.") + argparser.add_argument( + '--fold', + type=str, + default='0', + help="Fold to train on, can be a sting indicating an int, or can be 'all'.") + argparser.add_argument( + '--timestamp_selection', + type=str, + default='all', + help="Indicates how to determine the timestamp to pick for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far.") + argparser.add_argument( + '--num_institutions', + type=int, + default=1, + help="Number of symulated insitutions to shard the data into.") + argparser.add_argument( + '--cuda_device', + type=str, + default='0', + help="Used for the setting of os.environ['CUDA_VISIBLE_DEVICES']") + argparser.add_argument( + '--verbose', + action='store_true', + help="Print debuging information.") + + args = argparser.parse_args() + + kwargs = vars(args) + + main(**kwargs) diff --git a/examples/fl_post/fl/project/plan.py b/examples/fl_post/fl/project/plan.py new file mode 100644 index 000000000..2feb1bf52 --- /dev/null +++ b/examples/fl_post/fl/project/plan.py @@ -0,0 +1,16 @@ +import yaml + + +def generate_plan(training_config_path, aggregator_config_path, plan_path): + with open(training_config_path) as f: + training_config = yaml.safe_load(f) + with open(aggregator_config_path) as f: + aggregator_config = yaml.safe_load(f) + + # TODO: key checks. Also, define what should be considered aggregator_config + # (e.g., tls=true, reconnect_interval, ...) 
+ training_config["network"]["settings"]["agg_addr"] = aggregator_config["address"] + training_config["network"]["settings"]["agg_port"] = aggregator_config["port"] + + with open(plan_path, "w") as f: + yaml.dump(training_config, f) diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt new file mode 100644 index 000000000..c3eb2a404 --- /dev/null +++ b/examples/fl_post/fl/project/requirements.txt @@ -0,0 +1,3 @@ +onnx==1.13.0 +typer==0.9.0 +git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@main#egg=nnunet diff --git a/examples/fl_post/fl/project/src/__init__.py b/examples/fl_post/fl/project/src/__init__.py new file mode 100644 index 000000000..f1410b129 --- /dev/null +++ b/examples/fl_post/fl/project/src/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh b/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh new file mode 100644 index 000000000..b85de2bb2 --- /dev/null +++ b/examples/fl_post/fl/project/src/launch_collaborator_with_env_vars.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=$1 fx collaborator start -n $2 \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py new file mode 100644 index 000000000..1fe83a4f5 --- /dev/null +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -0,0 +1,36 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Brandon Edwards + +""" + +"""You may copy this file as the starting point of your own model.""" + +import json +import os + +class NNUNetDummyDataLoader(): + def __init__(self, data_path, p_train): + self.task_name = data_path + data_base_path = 
os.path.join(os.environ['nnUNet_preprocessed'], self.task_name) + with open(f'{data_base_path}/dataset.json', 'r') as f: + data_config = json.load(f) + data_size = data_config['numTraining'] + + # TODO: determine how nnunet validation splits round + self.train_data_size = int(p_train * data_size) + self.valid_data_size = data_size - self.train_data_size + + def get_feature_shape(self): + return [1,1,1] + + def get_train_data_size(self): + return self.train_data_size + + def get_valid_data_size(self): + return self.valid_data_size + + def get_task_name(self): + return self.task_name diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py new file mode 100644 index 000000000..701d02226 --- /dev/null +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -0,0 +1,278 @@ + + + + +# The following was copied and modified from the source: +# https://github.com/kaapana/kaapana/blob/26d71920d53c3110e2494cbb2ddb0cbb996b880a/data-processing/base-images/base-nnunet/files/patched/run_training.py#L213 + + +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# EDITED for OpenFL test integration by Brandon Edwards and Micah Sheller + + +import os +import numpy as np +import torch +import random +from batchgenerators.utilities.file_and_folder_operations import * +from nnunet.run.default_configuration import get_default_configuration +from nnunet.paths import default_plans_identifier +from nnunet.run.load_pretrained_weights import load_pretrained_weights +from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage +from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer +from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import ( + nnUNetTrainerCascadeFullRes, +) +from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import ( + nnUNetTrainerV2CascadeFullRes, +) +from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name + +def seed_everything(seed=1234): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def train_nnunet(epochs, + current_epoch, + network='3d_fullres', + network_trainer='nnUNetTrainerV2', + task='Task543_FakePostOpp_More', + fold='0', + continue_training=True, + validation_only=False, + c=False, + p=default_plans_identifier, + use_compressed_data=False, + deterministic=False, + npz=False, + find_lr=False, + valbest=False, + fp32=False, + val_folder='validation_raw', + disable_saving=False, + disable_postprocessing_on_folds=True, + val_disable_overwrite=True, + disable_next_stage_pred=False, + pretrained_weights=None): + + """ + task (int): can be task name or task id + fold: "0, 1, ..., 5 or 'all'" + validation_only: use this if you want to only run the validation + c: use this if you want to continue a training + p: plans identifier. 
Only change this if you created a custom experiment planner + use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data " + "is much more CPU and RAM intensive and should only be used if you know what you are " + "doing" + deterministic: "Makes training deterministic, but reduces training speed substantially. I (Fabian) think " + "this is not necessary. Deterministic training will make you overfit to some random seed. " + "Don't use that." + npz: "if set then nnUNet will " + "export npz files of " + "predicted segmentations " + "in the validation as well. " + "This is needed to run the " + "ensembling step so unless " + "you are developing nnUNet " + "you should enable this" + find_lr: not used here, just for fun + valbest: hands off. This is not intended to be used + fp32: disable mixed precision training and run old school fp32 + val_folder: name of the validation folder. No need to use this for most people + disable_saving: If set nnU-Net will not save any parameter files (except a temporary checkpoint that " + "will be removed at the end of the training). Useful for development when you are " + "only interested in the results and want to save some disk space + disable_postprocessing_on_folds: Running postprocessing on each fold only makes sense when developing with nnU-Net and " + "closely observing the model performance on specific configurations. You do not need it " + "when applying nnU-Net because the postprocessing for this will be determined only once " + "all five folds have been trained and nnUNet_find_best_configuration is called. Usually " + "running postprocessing on each fold is computationally cheap, but some users have " + "reported issues with very large images. If your images are large (>600x600x600 voxels) " + "you should consider setting this flag. 
+ val_disable_overwrite: If True, validation does not overwrite existing segmentations + pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model (use .model " + "file, for example model_final_checkpoint.model). Will only be used when actually training. " + "Optional. Beta. Use with caution." + disable_next_stage_pred: If True, do not predict next stage + """ + + class Arguments(): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + args = Arguments(**locals()) + + if args.deterministic: + seed_everything() + + task = args.task + fold = args.fold + network = args.network + network_trainer = args.network_trainer + validation_only = args.validation_only + plans_identifier = args.p + find_lr = args.find_lr + disable_postprocessing_on_folds = args.disable_postprocessing_on_folds + + use_compressed_data = args.use_compressed_data + decompress_data = not use_compressed_data + + deterministic = args.deterministic + valbest = args.valbest + + fp32 = args.fp32 + run_mixed_precision = not fp32 + + val_folder = args.val_folder + # interp_order = args.interp_order + # interp_order_z = args.interp_order_z + # force_separate_z = args.force_separate_z + + if not task.startswith("Task"): + task_id = int(task) + task = convert_id_to_task_name(task_id) + + if fold == "all": + pass + else: + fold = int(fold) + + # if force_separate_z == "None": + # force_separate_z = None + # elif force_separate_z == "False": + # force_separate_z = False + # elif force_separate_z == "True": + # force_separate_z = True + # else: + # raise ValueError("force_separate_z must be None, True or False. 
Given: %s" % force_separate_z) + + ( + plans_file, + output_folder_name, + dataset_directory, + batch_dice, + stage, + trainer_class, + ) = get_default_configuration(network, task, network_trainer, plans_identifier) + + if trainer_class is None: + raise RuntimeError( + "Could not find trainer class in nnunet.training.network_training" + ) + + if network == "3d_cascade_fullres": + assert issubclass( + trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes) + ), ( + "If running 3d_cascade_fullres then your " + "trainer class must be derived from " + "nnUNetTrainerCascadeFullRes" + ) + else: + assert issubclass( + trainer_class, nnUNetTrainer + ), "network_trainer was found but is not derived from nnUNetTrainer" + + trainer = trainer_class( + plans_file, + fold, + output_folder=output_folder_name, + dataset_directory=dataset_directory, + batch_dice=batch_dice, + stage=stage, + unpack_data=decompress_data, + deterministic=deterministic, + fp16=run_mixed_precision, + ) + # we want latest checkoint only (not best or any intermediate) + trainer.save_final_checkpoint = ( + True # whether or not to save the final checkpoint + ) + trainer.save_best_checkpoint = ( + False # whether or not to save the best checkpoint according to + ) + # self.best_val_eval_criterion_MA + trainer.save_intermediate_checkpoints = ( + False # whether or not to save checkpoint_latest. 
+ We need that in case + ) + # the training crashes + trainer.save_latest_only = ( + True # if false it will not store/overwrite _latest but separate files each + ) + trainer.max_num_epochs = current_epoch + epochs + trainer.epoch = current_epoch + + # TODO: call validation separately + trainer.initialize(not validation_only) + + if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": + trainer.save_checkpoint( + join(trainer.output_folder, "model_final_checkpoint.model") + ) + print("Preparation round: Model-averaging") + return + + if find_lr: + trainer.find_lr() + else: + if not validation_only: + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights + trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretrained weights, do nothing + pass + + trainer.run_training() + else: + # if valbest: + # trainer.load_best_checkpoint(train=False) + # else: + # trainer.load_final_checkpoint(train=False) + trainer.load_latest_checkpoint() + + trainer.network.eval() + + # if fold == "all": + # print("--> fold == 'all'") + # print("--> DONE") + # else: + # # predict validation + # trainer.validate( + # save_softmax=args.npz, + # validation_folder_name=val_folder, + # run_postprocessing_on_folds=not disable_postprocessing_on_folds, + # overwrite=args.val_disable_overwrite, + # ) + + # if network == "3d_lowres" and not args.disable_next_stage_pred: + # print("predicting segmentations for the next stage of the cascade") + # predict_next_stage( + # trainer, + # join( + # dataset_directory, + # trainer.plans["data_identifier"] + "_stage%d" % 1, + # ), + # ) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py new file mode 100644 
index 000000000..ecc3869eb --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -0,0 +1,233 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? + +""" +# TODO: Clean up imports + +import os +import subprocess +import shutil +import time +import pickle as pkl +from copy import deepcopy +import hashlib +import yaml + +import numpy as np +import torch + +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts + +from openfl.federated.task.runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict +from openfl.federated.task.runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy +from openfl.federated.task.nnunet_v1 import train_nnunet + +from .runner_pt_chkpt import PyTorchCheckpointTaskRunner +from .nnunet_v1 import train_nnunet + +class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): + """An abstract class for PyTorch model based Tasks, where training, validation etc. are processes that + pull model state from a PyTorch checkpoint.""" + + def __init__(self, + nnunet_task=None, + config_path=None, + **kwargs): + """Initialize. + + Args: + config_path(str) : Path to the configuration file used by the training and validation script. + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
+ TODO: + """ + + if 'nnUNet_raw_data_base' not in os.environ: + raise ValueError("NNUNet V1 requires that 'nnUNet_raw_data_base' be set either in the flplan or in the environment variables") + if 'nnUNet_preprocessed' not in os.environ: + raise ValueError("NNUNet V1 requires that 'nnUNet_preprocessed' be set either in the flplan or in the environment variables") + if 'RESULTS_FOLDER' not in os.environ: + raise ValueError("NNUNet V1 requires that 'RESULTS_FOLDER' be set either in the flplan or in the environment variables") + + super().__init__( + checkpoint_path_initial=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + 'model_initial_checkpoint.model' + ), + checkpoint_path_save=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + 'model_final_checkpoint.model' + ), + checkpoint_path_load=os.path.join( + os.environ['RESULTS_FOLDER'], + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + 'model_final_checkpoint.model' + ), + **kwargs, + ) + + self.config_path = config_path + + + def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): + """ + Save model state in tensor_dict to in a pickle file at self.checkpoint_out_path. Uses pt.save(). + All state in the checkpoint other than the model state will be kept as is in the file. 
+ Note: Utilization of a with_opt_vars input will be needed (along with saving an initial state optimizer state on disk), + will be needed if a self.opt_treatement of 'RESET' or 'AGG' are to be used + + Here is an example of a dictionary NNUnet uses for its state: + save_this = + { + 'epoch': self.epoch + 1, + 'state_dict': state_dict, + 'optimizer_state_dict': optimizer_state_dict, + 'lr_scheduler_state_dict': lr_sched_state_dct, + 'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, + self.all_val_eval_metrics), + 'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA) + } + + + Args: + tensor_dict (dictionary) : Dictionary with keys + with_opt_vars (bool) : Whether or not to save the optimizer state as well (this info will be part of the tensor dict in this case - i.e. tensor_dict = {**model_state, **opt_state}) + kwargs : unused + + Returns: + epoch + """ + # TODO: For now leaving the lr_scheduler_state_dict unchanged (this may be best though) + # TODO: Do we want to test this for 'RESET', 'CONTINUE_GLOBAL'? 
+ + # get device for correct placement of tensors + device = self.device + + checkpoint_dict = self.load_checkpoint(map_location=device) + epoch = checkpoint_dict['epoch'] + new_state = {} + # grabbing keys from the checkpoint state dict, poping from the tensor_dict + # Brandon DEBUGGING + seen_keys = [] + for k in checkpoint_dict['state_dict']: + if k not in seen_keys: + seen_keys.append(k) + else: + raise ValueError(f"\nKey {k} apears at least twice!!!!/n") + new_state[k] = torch.from_numpy(tensor_dict[k].copy()).to(device) + checkpoint_dict['state_dict'] = new_state + + if with_opt_vars: + # see if there is state to restore first + if tensor_dict.pop('__opt_state_needed') == 'true': + checkpoint_dict = self._set_optimizer_state(derived_opt_state_dict=tensor_dict, + checkpoint_dict=checkpoint_dict) + self.save_checkpoint(checkpoint_dict) + + # FIXME: this should be unnecessary now + # we may want to know epoch so that we can properly tell the training script to what epoch to train (NNUnet V1 only supports training with a max_num_epochs setting) + return epoch + + + def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): + # TODO: Figure out the right name to use for this method and the default assigner + """Perform training for a specified number of epochs.""" + + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + # 2. Train function existing externally + # Some todo inside function below + # TODO: test for off-by-one error + # TODO: we need to disable validation if possible, and separately call validation + + # FIXME: we need to understand how to use round_num instead of current_epoch + # this will matter in straggler handling cases + # TODO: Should we put this in a separate process? + train_nnunet(epochs=epochs, current_epoch=current_epoch, task=self.data_loader.get_task_name()) + + # 3. 
Load metrics from checkpoint + (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkopint each epoch, so we select the most recent epoch + metrics = {'train_loss': all_tr_losses[-1], + 'val_eval': all_val_eval_metrics[-1]} + + return self.convert_results_to_tensorkeys(col_name, round_num, metrics) + + + + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + """ + Run the trained model on validation data; report results. + + Parameters + ---------- + input_tensor_dict : either the last aggregated or locally trained model + + Returns + ------- + output_tensor_dict : {TensorKey: nparray} (these correspond to acc, + precision, f1_score, etc.) + """ + + raise NotImplementedError() + + """ - TBD - for now commenting out + + self.rebuild_model(round_num, input_tensor_dict, validation=True) + + # 1. Save model in native format + self.save_native(self.mlcube_model_in_path) + + # 2. Call MLCube validate task + platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) + task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') + proc = subprocess.run(["mlcube_docker", + "run", + "--mlcube={}".format(self.mlcube_dir), + "--platform={}".format(platform_yaml), + "--task={}".format(task_yaml)]) + + # 3. Load any metrics + metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) + + # set the validation data size + sample_count = int(metrics.pop(self.evaluation_sample_count_key)) + self.data_loader.set_valid_data_size(sample_count) + + # 4. 
Convert to tensorkeys + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + output_tensor_dict = { + TensorKey( + metric_name, origin, round_num, True, tags + ): np.array(metrics[metric_name]) + for metric_name in metrics + } + + return output_tensor_dict, {} + + """ + + + def load_metrics(self, filepath): + """ + Load metrics from file on disk + """ + raise NotImplementedError() + """ + with open(filepath) as json_file: + metrics = json.load(json_file) + return metrics + """ \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py new file mode 100644 index 000000000..f1d167ae1 --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -0,0 +1,321 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? + +""" +# TODO: Clean up imports + +import os +import shutil +from copy import deepcopy + +import numpy as np +import torch + +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts + +from openfl.federated.task.runner import TaskRunner +from openfl.federated.task.runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict +from openfl.federated.task.runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy + + +class PyTorchCheckpointTaskRunner(TaskRunner): + """An abstract class for PyTorch model based Tasks, where training, validation etc. are processes that + pull model state from a PyTorch checkpoint.""" + + def __init__(self, + device = 'cuda', + gpu_num_string = '0', + checkpoint_path_initial = None, + checkpoint_path_save = None, + checkpoint_path_load = None, + **kwargs): + """Initialize. 
+ + Args: + device(str) : Device ('cpu' or 'cuda') to be used for training and validation script computations. + checkpoint_path_initial(str): Path to the model checkpoint that will be used to initialize this object and copied to the 'write' path to start. + checkpoint_path_save(str) : Path to the model checkpoint that will be saved and passed into the training function. + checkpoint_path_load(str) : Path to the model checkpoint that will be loaded. It is also the output file path for the training function. + kwargs : Additional keyword arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). + TODO: + """ + super().__init__(**kwargs) + + self.checkpoint_path_initial = checkpoint_path_initial + self.checkpoint_path_save = checkpoint_path_save + self.checkpoint_path_load = checkpoint_path_load + self.gpu_num_string = gpu_num_string + + # TODO: Understand why "weights-only" + + # TODO: Both 'CONTINUE_GLOBAL' and 'RESET' could be supported here too (currently RESET throws an exception related to a + # mismatch in size coming from the momentum buffer and other stuff either in the model or optimizer) + self.opt_treatment = 'CONTINUE_LOCAL' + + if device not in ['cpu', 'cuda']: + raise ValueError("Device argument must be 'cpu' or 'cuda', but {device} was used instead.") + self.device = device + + self.training_round_completed = False + + # enable GPUs if appropriate + if self.device == 'cuda' and not self.gpu_num_string: + raise ValueError(f"If device is 'cuda' then gpu_num must be set rather than allowing to be the default None.") + else: + os.environ['CUDA_VISIBLE_DEVICES']= self.gpu_num_string + + self.required_tensorkeys_for_function = {} + self.initialize_tensorkeys_for_functions() + + # overwrite attribute to account for one optimizer param (in every + # child model that does not overwrite get and set tensordict) that is + # not a numpy array + self.tensor_dict_split_fn_kwargs.update({ + 'holdout_tensor_names': ['__opt_state_needed'] + 
}) + + # Initialize model + self.replace_checkpoint(self.checkpoint_path_initial) + + + def load_checkpoint(self, checkpoint_path=None, map_location=None): + """ + Function used to load checkpoint from disk. + """ + if not checkpoint_path: + checkpoint_path = self.checkpoint_path_load + checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) + return checkpoint_dict + + def save_checkpoint(self, checkpoint_dict): + """ + Function to save checkpoint to disk. + """ + torch.save(checkpoint_dict, self.checkpoint_path_save) + + # defining some class methods using some util functions imported above + + def rebuild_model(self, input_tensor_dict, **kwargs): + rebuild_model_util(runner_class=self, input_tensor_dict=input_tensor_dict, **kwargs) + + def initialize_tensorkeys_for_functions(self, **kwargs): + initialize_tensorkeys_for_functions_util(runner_class=self, **kwargs) + + def get_required_tensorkeys_for_function(self, func_name, **kwargs): + """ + Get the required tensors for specified function that could be called \ + as part of a task. By default, this is just all of the layers and \ + optimizer of the model. + + Args: + func_name + + Returns: + list : [TensorKey] + """ + if func_name == 'validate': + local_model = 'apply=' + str(kwargs['apply']) + return self.required_tensorkeys_for_function[func_name][local_model] + else: + return self.required_tensorkeys_for_function[func_name] + + def reset_opt_vars(self): + current_checkpoint_dict = self.load_checkpoint() + initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) + derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) + self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, + checkpoint_dict=current_checkpoint_dict) + + def set_tensor_dict(self, tensor_dict, with_opt_vars=False): + """Set the tensor dictionary. 
+ + Args: + tensor_dict: The tensor dictionary + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + """ + return self.write_tensors_into_checkpoint(tensor_dict=tensor_dict, with_opt_vars=with_opt_vars) + + def replace_checkpoint(self, path_to_replacement): + checkpoint_dict = self.load_checkpoint(checkpoint_path=path_to_replacement) + self.save_checkpoint(checkpoint_dict) + # shutil.copyfile(src=path_to_replacement, dst=self.checkpoint_path_save) + + def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): + raise NotImplementedError + + def get_tensor_dict(self, with_opt_vars=False): + """Return the tensor dictionary. + + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + + Returns: + dict: Tensor dictionary {**dict, **optimizer_dict} + + """ + return self.read_tensors_from_checkpoint(with_opt_vars=with_opt_vars) + + def read_tensors_from_checkpoint(self, with_opt_vars): + """Return a tensor dictionary interpreted from a checkpoint. + + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False) + + Returns: + dict: Tensor dictionary {**dict, **optimizer_dict} + + """ + checkpoint_dict = self.load_checkpoint() + state = to_cpu_numpy(checkpoint_dict['state_dict']) + if with_opt_vars: + opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) + state = {**state, **opt_state} + return state + + def _get_weights_names(self, with_opt_vars=False): + """ + Gets model and potentially optimizer state dict key names + args: + with_opt_vars(bool) : Whether or not to get the optimizer key names + """ + state = self.get_tensor_dict(with_opt_vars=with_opt_vars) + return state.keys() + + def _set_optimizer_state(self, derived_opt_state_dict, checkpoint_dict): + """Set the optimizer state. 
+ # TODO: Refactor this, we will sparate the custom aspect of the checkpoint dict from the more general code + + Args: + derived_opt_state_dict(bool) : flattened optimizer state dict + checkpoint_dict(dict) : checkpoint dictionary + + """ + self._write_optimizer_state_into_checkpoint(derived_opt_state_dict=derived_opt_state_dict, + checkpoint_dict=checkpoint_dict, + checkpoint_path=self.checkpoint_out_path) + + def _write_optimizer_state_into_checkpoint(self, derived_opt_state_dict, checkpoint_dict, checkpoint_path): + """Write the optimizer state contained within the derived_opt_state_dict into the checkpoint_dict, + keeping some settings already contained within that checkpoint file the same, then write the resulting + checkpoint back to the checkpoint path. + TODO: Refactor this, we will separate the custom aspect of the checkpoint dict from the more general code + + Args: + derived_opt_state_dict(bool) : flattened optimizer state dict + checkpoint_dir(path) : Path to the checkpoint file + + """ + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, device=self.device) + # Note: The expansion above only populates the 'params' key of each param group under opt_state_dict['param_groups'] + # Therefore the default values under the additional keys such as: 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable' + # need to be held over from the their initial values. + # FIXME: Figure out whether or not this breaks learning rate scheduling and the like. + + # Letting default values (everything under temp_state_dict['param_groups'] except the 'params' key) + # stay unchanged (these are not contained in the temp_state_dict) + # Assuming therefore that the optimizer.defaults (which hold this same info) are not changing over course of training. 
+ # We only modify the 'state' key value pairs otherwise + for group_idx, group in enumerate(temp_state_dict['param_groups']): + checkpoint_dict['optimizer_state_dict']['param_groups'][group_idx]['params'] = group['params'] + checkpoint_dict['optimizer_state_dict']['state'] = temp_state_dict['state'] + + torch.save(checkpoint_dict, checkpoint_path) + + def _get_optimizer_state(self, checkpoint_dict): + """Get the optimizer state. + Args: + checkpoint_path(str) : path to the checkpoint + """ + return self._read_opt_state_from_checkpoint(checkpoint_dict) + + + def _read_opt_state_from_checkpoint(self, checkpoint_dict): + """Read the optimizer state from the checkpoint dict and put in tensor dict format. + # TODO: Refactor this, we will sparate the custom aspect of the checkpoint dict from the more general code + """ + + opt_state_dict = deepcopy(checkpoint_dict['optimizer_state_dict']) + + # Optimizer state might not have some parts representing frozen parameters + # So we do not synchronize them + param_keys_with_state = set(opt_state_dict['state'].keys()) + for group in opt_state_dict['param_groups']: + local_param_set = set(group['params']) + params_to_sync = local_param_set & param_keys_with_state + group['params'] = sorted(params_to_sync) + derived_opt_state_dict = derive_opt_state_dict(opt_state_dict) + + return derived_opt_state_dict + + + def convert_results_to_tensorkeys(self, col_name, round_num, metrics): + # 5. 
Convert to tensorkeys + + # output metric tensors (scalar) + origin = col_name + tags = ('trained',) + output_metric_dict = { + TensorKey( + metric_name, origin, round_num, True, ('metric',) + ): np.array( + metrics[metric_name] + ) for metric_name in metrics} + + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) + + # create global tensorkeys + global_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num, False, tags + ): nparray for tensor_name, nparray in global_model_dict.items() + } + # create tensorkeys that should stay local + local_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num, False, tags + ): nparray for tensor_name, nparray in local_model_dict.items() + } + # the train/validate aggregated function of the next round will look + # for the updated model parameters. + # this ensures they will be resolved locally + next_local_tensorkey_model_dict = { + TensorKey( + tensor_name, origin, round_num + 1, False, ('model',) + ): nparray for tensor_name, nparray in local_model_dict.items() + } + + global_tensor_dict = { + **output_metric_dict, + **global_tensorkey_model_dict + } + local_tensor_dict = { + **local_tensorkey_model_dict, + **next_local_tensorkey_model_dict + } + + # update the required tensors if they need to be pulled from the + # aggregator + # TODO this logic can break if different collaborators have different + # roles between rounds. + # for example, if a collaborator only performs validation in the first + # round but training in the second, it has no way of knowing the + # optimizer state tensor names to request from the aggregator + # because these are only created after training occurs. 
+ # A work around could involve doing a single epoch of training + # on random data to get the optimizer names, and then throwing away + # the model. + if self.opt_treatment == 'CONTINUE_GLOBAL': + self.initialize_tensorkeys_for_functions(with_opt_vars=True) + + return global_tensor_dict, local_tensor_dict diff --git a/examples/fl_post/fl/project/utils.py b/examples/fl_post/fl/project/utils.py new file mode 100644 index 000000000..e6653ae76 --- /dev/null +++ b/examples/fl_post/fl/project/utils.py @@ -0,0 +1,132 @@ +import yaml +import os +import shutil + + +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") + + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_aggregator_fqdn(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return plan["network"]["settings"]["agg_addr"].lower() + + +def get_collaborator_cn(): + # TODO: check if there is a way this can cause a collision/race condition + # TODO: from inside the file + return os.environ["MEDPERF_PARTICIPANT_LABEL"] + + +def get_weights_path(fl_workspace): + plan_path = os.path.join(fl_workspace, "plan", "plan.yaml") + plan = yaml.safe_load(open(plan_path)) + return { + "init": plan["aggregator"]["settings"]["init_state_path"], + "best": plan["aggregator"]["settings"]["best_state_path"], + "last": plan["aggregator"]["settings"]["last_state_path"], + } + + +def prepare_plan(plan_path, fl_workspace): + target_plan_folder = os.path.join(fl_workspace, "plan") + # TODO: permissions + os.makedirs(target_plan_folder, exist_ok=True) + + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + shutil.copyfile(plan_path, 
target_plan_file) + + +def prepare_cols_list(collaborators_file, fl_workspace): + with open(collaborators_file) as f: + cols_dict = yaml.safe_load(f) + cn_different = False + for col_label in cols_dict.keys(): + cn = cols_dict[col_label] + if cn != col_label: + cn_different = True + if not cn_different: + # quick hack to support old and new openfl versions + cols_dict = list(cols_dict.keys()) + + target_plan_folder = os.path.join(fl_workspace, "plan") + # TODO: permissions + os.makedirs(target_plan_folder, exist_ok=True) + target_plan_file = os.path.join(target_plan_folder, "cols.yaml") + with open(target_plan_file, "w") as f: + yaml.dump({"collaborators": cols_dict}, f) + + +def prepare_init_weights(input_weights, fl_workspace): + error_msg = f"{input_weights} should contain only one file: *.pbuf" + + files = os.listdir(input_weights) + file = files[0] # TODO: this may cause failure in MAC OS + if len(files) != 1 or not file.endswith(".pbuf"): + raise RuntimeError(error_msg) + + file = os.path.join(input_weights, file) + + target_weights_subpath = get_weights_path(fl_workspace)["init"] + target_weights_path = os.path.join(fl_workspace, target_weights_subpath) + target_weights_folder = os.path.dirname(target_weights_path) + # TODO: permissions + os.makedirs(target_weights_folder, exist_ok=True) + os.symlink(file, target_weights_path) + + +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): + error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" + + files = os.listdir(node_cert_folder) + file_extensions = [file.split(".")[-1] for file in files] + if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]: + raise RuntimeError(error_msg) + + if files[0].endswith(".crt") and files[1].endswith(".key"): + cert_file = files[0] + key_file = files[1] + else: + key_file = files[0] + cert_file = files[1] + + key_file = os.path.join(node_cert_folder, key_file) + cert_file = 
os.path.join(node_cert_folder, cert_file) + + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + # TODO: permissions + os.makedirs(target_cert_folder, exist_ok=True) + target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") + target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") + + os.symlink(key_file, target_key_file) + os.symlink(cert_file, target_cert_file) + + +def prepare_ca_cert(ca_cert_folder, fl_workspace): + error_msg = f"{ca_cert_folder} should contain only one file: *.crt" + + files = os.listdir(ca_cert_folder) + file = files[0] + if len(files) != 1 or not file.endswith(".crt"): + raise RuntimeError(error_msg) + + file = os.path.join(ca_cert_folder, file) + + target_ca_cert_folder = os.path.join(fl_workspace, "cert") + # TODO: permissions + os.makedirs(target_ca_cert_folder, exist_ok=True) + target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") + + os.symlink(file, target_ca_cert_file) \ No newline at end of file diff --git a/examples/fl_post/fl/setup_clean.sh b/examples/fl_post/fl/setup_clean.sh new file mode 100644 index 000000000..9f9242024 --- /dev/null +++ b/examples/fl_post/fl/setup_clean.sh @@ -0,0 +1,5 @@ +rm -rf ./mlcube_agg +rm -rf ./mlcube_col1 +rm -rf ./mlcube_col2 +rm -rf ./mlcube_col3 +rm -rf ./ca diff --git a/examples/fl_post/fl/setup_test.sh b/examples/fl_post/fl/setup_test.sh new file mode 100644 index 000000000..75a3d68f5 --- /dev/null +++ b/examples/fl_post/fl/setup_test.sh @@ -0,0 +1,124 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this 
case this var is not used actually. it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col2 +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert 
+else + sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ +fi + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL,$COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL,$COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL,$COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. 
+ +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. diff --git a/examples/fl_post/fl/sync.sh b/examples/fl_post/fl/sync.sh new file mode 100755 index 000000000..a5375ce54 --- /dev/null +++ b/examples/fl_post/fl/sync.sh @@ -0,0 +1,6 @@ +cp mlcube/workspace/training_config.yaml mlcube_agg/workspace/training_config.yaml + +cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml +cp mlcube/mlcube.yaml mlcube_col3/mlcube.yaml diff --git a/examples/fl_post/fl/test_agg.sh b/examples/fl_post/fl/test_agg.sh new file mode 100755 index 000000000..f9bb9faec --- /dev/null +++ b/examples/fl_post/fl/test_agg.sh @@ -0,0 +1,27 @@ +# generate plan and copy it to each node +medperf mlcube run --mlcube ./mlcube_agg --task generate_plan +mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +rm -r ./mlcube_agg/workspace/plan +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace + +medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273 + +# medperf --gpus="device=0" mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro 
--volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 +exit +# Run nodes +AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" + +gnome-terminal -- bash -c "$AGG; bash" +gnome-terminal -- bash -c "$COL1; bash" +gnome-terminal -- bash -c "$COL2; bash" +gnome-terminal -- bash -c "$COL3; bash" + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume 
/home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/test_col1.sh b/examples/fl_post/fl/test_col1.sh new file mode 100755 index 000000000..fc47280f6 --- /dev/null +++ b/examples/fl_post/fl/test_col1.sh @@ -0,0 +1,25 @@ +# generate plan and copy it to each node +# medperf mlcube run --mlcube ./mlcube_agg --task generate_plan +# mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +# rm -r ./mlcube_agg/workspace/plan +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace + +medperf --gpus="device=0" mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train 
--data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 +exit +# Run nodes +AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" + +gnome-terminal -- bash -c "$AGG; bash" +gnome-terminal -- bash -c "$COL1; bash" +gnome-terminal -- bash -c "$COL2; bash" +gnome-terminal -- bash -c "$COL3; bash" + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/test_col2.sh b/examples/fl_post/fl/test_col2.sh new file mode 100755 index 000000000..1b9fae628 --- /dev/null +++ b/examples/fl_post/fl/test_col2.sh @@ -0,0 +1,25 @@ +# generate plan and copy it to each node +# medperf mlcube 
run --mlcube ./mlcube_agg --task generate_plan +# mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +# rm -r ./mlcube_agg/workspace/plan +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace + +medperf --gpus="device=1" mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 +exit +# Run nodes +AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e 
MEDPERF_PARTICIPANT_LABEL=col3@example.com" + +gnome-terminal -- bash -c "$AGG; bash" +gnome-terminal -- bash -c "$COL1; bash" +gnome-terminal -- bash -c "$COL2; bash" +gnome-terminal -- bash -c "$COL3; bash" + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 From 5c83c5b36747c1cc46191e2caf5b00c3cad73b6c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 12 Jun 2024 19:46:59 +0200 Subject: [PATCH 077/242] use same setup files as fl/fl --- examples/fl_post/fl/.test.sh.swp | Bin 12288 -> 0 bytes examples/fl_post/fl/build.sh | 6 + examples/fl_post/fl/clean.sh | 0 examples/fl_post/fl/csr.conf | 24 ++-- examples/fl_post/fl/mlcube/mlcube.yaml | 2 +- examples/fl_post/fl/setup_test.sh | 65 +++------- examples/fl_post/fl/setup_test_no_docker.sh | 124 +++++++++++++++++++ examples/fl_post/fl/sync.sh | 0 examples/fl_post/fl/{test_agg.sh => test.sh} | 6 - examples/fl_post/fl/test_col1.sh | 25 ---- examples/fl_post/fl/test_col2.sh | 25 ---- 11 files changed, 164 insertions(+), 113 deletions(-) delete mode 100644 examples/fl_post/fl/.test.sh.swp mode change 100755 => 100644 examples/fl_post/fl/build.sh mode 
change 100755 => 100644 examples/fl_post/fl/clean.sh create mode 100644 examples/fl_post/fl/setup_test_no_docker.sh mode change 100755 => 100644 examples/fl_post/fl/sync.sh rename examples/fl_post/fl/{test_agg.sh => test.sh} (59%) mode change 100755 => 100644 delete mode 100755 examples/fl_post/fl/test_col1.sh delete mode 100755 examples/fl_post/fl/test_col2.sh diff --git a/examples/fl_post/fl/.test.sh.swp b/examples/fl_post/fl/.test.sh.swp deleted file mode 100644 index 327381c6955512a7f275e4252a493383572270b1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2L2uJA6vsV;5J(IVh`VW#I3;P)O-w_R7&}ZzgQ1EcaSCqgCS^%%+3rR!Ai)W7 z=D?jVuulL&;xiyV3m2Z-q{{|$K~oQ)7wIW=-n0Mlug}k_>Tc`F{Rec#^$A>82-!Sv z`#bK(bMHE5Nv0E?rd(OqtV>Jf#B$ggk1-?gWjUZ14~&=bp5&_UBuum1<1bhyQ?9*! z8cLxJ+&#m!adk3dM06nmBv6aM(Da>^#hrHGw`_c~af4pFx>Ktw1|b0?fCP{L5{EBdvJ@i3(8!?)QxWURz{{jbdm`H9imIV~ZJzJZjqR%&@$ZD<;Uu2OMB3=0B6VaY!@wA#63%E zwe$g~RVC)2GGSlQ@#mv6{U$3?!F;5Otf&aJ=OR=QSelD~XEU>+vP_+U87HJsBPj;C zph_&+yKz!nP&3Z2 z6Suy(S%wk&4r`6#FtVKql#pWG6%Dss1F$~qx;9MjmL-M%Zeg~o3$tMrg8B-!temJ2 z{G#C3SFm;MM1>>8r9psLO#v3i<@nk#D4g&>@t7H*sPl*}w}RX42B96OY<5X%`1vU5 z6fp!>D$rrAoCetM;4tE^`v#}zTd9piLBUnCUYlflv)q+!U^=xUb0@VjvI0BiIakcU f_6Hjb>>y>E+BRi*k46HvLRh~>P|l)cjGz1jYQDe5 diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh old mode 100755 new mode 100644 index d56304274..67cda94a7 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -1 +1,7 @@ +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. 
+rm -rf openfl mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/clean.sh b/examples/fl_post/fl/clean.sh old mode 100755 new mode 100644 diff --git a/examples/fl_post/fl/csr.conf b/examples/fl_post/fl/csr.conf index 5ac85ae39..c3b2d0f0c 100644 --- a/examples/fl_post/fl/csr.conf +++ b/examples/fl_post/fl/csr.conf @@ -3,21 +3,29 @@ default_bits = 3072 prompt = no default_md = sha384 distinguished_name = req_distinguished_name -req_extensions = req_ext [ req_distinguished_name ] -commonName = spr-gpu01.jf.intel.com - -[ req_ext ] -basicConstraints = critical,CA:FALSE -keyUsage = critical,digitalSignature,keyEncipherment -subjectAltName = @alt_names +commonName = hasan-hp-zbook-15-g3.home [ alt_names ] -DNS.1 = spr-gpu01.jf.intel.com +DNS.1 = hasan-hp-zbook-15-g3.home [ v3_client ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,clientAuth [ v3_server ] +basicConstraints = critical,CA:FALSE +keyUsage = critical,digitalSignature,keyEncipherment +subjectAltName = @alt_names extendedKeyUsage = critical,serverAuth + +[ v3_client_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names + +[ v3_server_crt ] +basicConstraints = critical,CA:FALSE +subjectAltName = @alt_names diff --git a/examples/fl_post/fl/mlcube/mlcube.yaml b/examples/fl_post/fl/mlcube/mlcube.yaml index 39ecc21a9..b13dc4eb8 100644 --- a/examples/fl_post/fl/mlcube/mlcube.yaml +++ b/examples/fl_post/fl/mlcube/mlcube.yaml @@ -9,7 +9,7 @@ platform: docker: gpu_args: "--shm-size 12g" # Image name - image: msheller/mlcube_testing:nnunet_fl_test + image: local/tmp:0.0.0 # Docker build context relative to $MLCUBE_ROOT. Default is `build`. build_context: "../project" # Docker file name within docker build context, default is `Dockerfile`. 
diff --git a/examples/fl_post/fl/setup_test.sh b/examples/fl_post/fl/setup_test.sh index 75a3d68f5..542dd7164 100644 --- a/examples/fl_post/fl/setup_test.sh +++ b/examples/fl_post/fl/setup_test.sh @@ -31,72 +31,41 @@ mkdir ./ca HOSTNAME_=$(hostname -A | cut -d " " -f 1) -# root ca -openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 -openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ - -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" +medperf mlcube run --mlcube ../mock_cert/mlcube --task trust +mv ../mock_cert/mlcube/workspace/pki_assets/* ./ca # col1 -sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf -cd mlcube_col1/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL1_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col1/workspace/node_cert +cp -r ./ca/* ./mlcube_col1/workspace/ca_cert # col2 -sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf -cd mlcube_col2/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e 
MEDPERF_INPUT_CN=$COL2_CN +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col2/workspace/node_cert +cp -r ./ca/* ./mlcube_col2/workspace/ca_cert # col3 if ${TWO_COL_SAME_CERT}; then cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert else - sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf - sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf - cd mlcube_col3/workspace/node_cert - openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 - openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client - openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 - rm csr.csr - cp ../../../ca/root.crt ../ca_cert/ - cd ../../../ + medperf mlcube run --mlcube ../mock_cert/mlcube --task get_client_cert -e MEDPERF_INPUT_CN=$COL3_CN + mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_col3/workspace/node_cert + cp -r ./ca/* ./mlcube_col3/workspace/ca_cert fi -# agg -sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf -cd mlcube_agg/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd ../../../ +medperf mlcube run --mlcube ../mock_cert/mlcube --task get_server_cert -e MEDPERF_INPUT_CN=$HOSTNAME_ +mv ../mock_cert/mlcube/workspace/pki_assets/* ./mlcube_agg/workspace/node_cert +cp -r ./ca/* ./mlcube_agg/workspace/ca_cert # aggregator_config echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols 
file -echo "$COL1_LABEL,$COL1_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL2_LABEL,$COL2_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL3_LABEL,$COL3_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml # data download cd mlcube_col1/workspace/ diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh new file mode 100644 index 000000000..879e84ced --- /dev/null +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -0,0 +1,124 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org2@example.com" # in this case this var is not used actually. 
it's OK +fi + +cp -r ./mlcube ./mlcube_agg +cp -r ./mlcube ./mlcube_col1 +cp -r ./mlcube ./mlcube_col2 +cp -r ./mlcube ./mlcube_col3 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col2 +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# col3 +if ${TWO_COL_SAME_CERT}; then + cp mlcube_col2/workspace/node_cert/* 
mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd ../../../ +fi + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd ../../../ + +# aggregator_config +echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# data download +cd mlcube_col1/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz +tar -xf col1_prepared.tar.gz +rm col1_prepared.tar.gz +cd ../.. + +cd mlcube_col2/workspace/ +wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz +tar -xf col2_prepared.tar.gz +rm col2_prepared.tar.gz +cd ../.. 
+ +cp -r mlcube_col2/workspace/data mlcube_col3/workspace +cp -r mlcube_col2/workspace/labels mlcube_col3/workspace + +# weights download +cd mlcube_agg/workspace/ +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz +tar -xf init_weights_miccai.tar.gz +rm init_weights_miccai.tar.gz +cd ../../.. diff --git a/examples/fl_post/fl/sync.sh b/examples/fl_post/fl/sync.sh old mode 100755 new mode 100644 diff --git a/examples/fl_post/fl/test_agg.sh b/examples/fl_post/fl/test.sh old mode 100755 new mode 100644 similarity index 59% rename from examples/fl_post/fl/test_agg.sh rename to examples/fl_post/fl/test.sh index f9bb9faec..3a154936a --- a/examples/fl_post/fl/test_agg.sh +++ b/examples/fl_post/fl/test.sh @@ -6,12 +6,6 @@ cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace -medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273 - -# medperf --gpus="device=0" mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume 
/home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 -exit # Run nodes AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" diff --git a/examples/fl_post/fl/test_col1.sh b/examples/fl_post/fl/test_col1.sh deleted file mode 100755 index fc47280f6..000000000 --- a/examples/fl_post/fl/test_col1.sh +++ /dev/null @@ -1,25 +0,0 @@ -# generate plan and copy it to each node -# medperf mlcube run --mlcube ./mlcube_agg --task generate_plan -# mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace -# rm -r ./mlcube_agg/workspace/plan -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace - -medperf --gpus="device=0" mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume 
/home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 -exit -# Run nodes -AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" - -gnome-terminal -- bash -c "$AGG; bash" -gnome-terminal -- bash -c "$COL1; bash" -gnome-terminal -- bash -c "$COL2; bash" -gnome-terminal -- bash -c "$COL3; bash" - -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 
--plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/test_col2.sh b/examples/fl_post/fl/test_col2.sh deleted file mode 100755 index 1b9fae628..000000000 --- a/examples/fl_post/fl/test_col2.sh +++ /dev/null @@ -1,25 +0,0 @@ -# generate plan and copy it to each node -# medperf mlcube run --mlcube ./mlcube_agg --task generate_plan -# mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace -# rm -r ./mlcube_agg/workspace/plan -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace -# cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace - -medperf --gpus="device=1" mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/additional_files/init_nnunet:/mlcube_io5:ro --volume /home/msheller/git/medperf-hasan/examples/fl_post/fl/mlcube_col1/workspace/logs:/mlcube_io6 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --init_nnunet_directory=/mlcube_io5 --output_logs=/mlcube_io6 -exit -# Run nodes -AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" 
-COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" - -gnome-terminal -- bash -c "$AGG; bash" -gnome-terminal -- bash -c "$COL1; bash" -gnome-terminal -- bash -c "$COL2; bash" -gnome-terminal -- bash -c "$COL3; bash" - -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 From b44ab75cd1133c97407c569f977a45277ad04057 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Thu, 13 Jun 2024 09:49:55 -0700 Subject: [PATCH 078/242] Added missing utils file and changed import path to read from the src folder --- .../fl_post/fl/project/src/runner_pt_chkpt.py | 6 +- .../fl_post/fl/project/src/runner_pt_utils.py | 278 ++++++++++++++++++ 2 files changed, 281 insertions(+), 3 deletions(-) create mode 100644 examples/fl_post/fl/project/src/runner_pt_utils.py diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py 
b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index f1d167ae1..6ab7851b9 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -1,5 +1,5 @@ # Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0 """ Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? @@ -18,8 +18,8 @@ from openfl.utilities.split import split_tensor_dict_for_holdouts from openfl.federated.task.runner import TaskRunner -from openfl.federated.task.runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict -from openfl.federated.task.runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy +from .runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict +from .runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy class PyTorchCheckpointTaskRunner(TaskRunner): diff --git a/examples/fl_post/fl/project/src/runner_pt_utils.py b/examples/fl_post/fl/project/src/runner_pt_utils.py new file mode 100644 index 000000000..28fa33de2 --- /dev/null +++ b/examples/fl_post/fl/project/src/runner_pt_utils.py @@ -0,0 +1,278 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities modeule for PyTorch related Task Runners""" + +# NOTE: this might want to be its own PR to openfl + +from copy import deepcopy +import torch as pt +import numpy as np + +from openfl.utilities.split import split_tensor_dict_for_holdouts +from openfl.utilities import TensorKey + + +def rebuild_model_util(runner_class, input_tensor_dict, testing_with_opt_setting=False, **kwargs): + """ + Parse tensor names and update weights of model. 
Assumes opt_treatement == CONTINUE_LOCAL, but + allows for writing in optimizer variables for testing purposes + + Returns: + None + """ + if testing_with_opt_setting: + with_opt_vars = True + else: + with_opt_vars = False + + runner_class.set_tensor_dict(input_tensor_dict, with_opt_vars=with_opt_vars) + + +def derive_opt_state_dict(opt_state_dict): + """Separate optimizer tensors from the tensor dictionary. + + Flattens the optimizer state dict so as to have key, value pairs with + values as numpy arrays. + The keys have sufficient info to restore opt_state_dict using + expand_derived_opt_state_dict. + + Args: + opt_state_dict: The optimizer state dictionary + + """ + derived_opt_state_dict = {} + + # Determine if state is needed for this optimizer. + if len(opt_state_dict['state']) == 0: + derived_opt_state_dict['__opt_state_needed'] = 'false' + return derived_opt_state_dict + + derived_opt_state_dict['__opt_state_needed'] = 'true' + + # Using one example state key, we collect keys for the corresponding + # dictionary value. + example_state_key = opt_state_dict['param_groups'][0]['params'][0] + example_state_subkeys = set( + opt_state_dict['state'][example_state_key].keys() + ) + + + + # We assume that the state collected for all params in all param groups is + # the same. + # We also assume that whether or not the associated values to these state + # subkeys is a tensor depends only on the subkey. + # Using assert statements to break the routine if these assumptions are + # incorrect. 
+ for state_key in opt_state_dict['state'].keys(): + assert example_state_subkeys == set(opt_state_dict['state'][state_key].keys()) + for state_subkey in example_state_subkeys: + assert (isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor) + == isinstance( + opt_state_dict['state'][state_key][state_subkey], + pt.Tensor)) + + state_subkeys = list(opt_state_dict['state'][example_state_key].keys()) + + # Tags will record whether the value associated to the subkey is a + # tensor or not. + state_subkey_tags = [] + for state_subkey in state_subkeys: + if isinstance( + opt_state_dict['state'][example_state_key][state_subkey], + pt.Tensor + ): + state_subkey_tags.append('istensor') + else: + state_subkey_tags.append('') + state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags)) + + # Forming the flattened dict, using a concatenation of group index, + # subindex, tag, and subkey inserted into the flattened dict key - + # needed for reconstruction. + nb_params_per_group = [] + for group_idx, group in enumerate(opt_state_dict['param_groups']): + for idx, param_id in enumerate(group['params']): + for subkey, tag in state_subkeys_and_tags: + if tag == 'istensor': + new_v = opt_state_dict['state'][param_id][ + subkey].cpu().numpy() + else: + new_v = np.array( + [opt_state_dict['state'][param_id][subkey]] + ) + derived_opt_state_dict[f'__opt_state_{group_idx}_{idx}_{tag}_{subkey}'] = new_v + nb_params_per_group.append(idx + 1) + # group lengths are also helpful for reconstructing + # original opt_state_dict structure + derived_opt_state_dict['__opt_group_lengths'] = np.array( + nb_params_per_group + ) + return derived_opt_state_dict + + +def expand_derived_opt_state_dict(derived_opt_state_dict, device): + """Expand the optimizer state dictionary. + + Takes a derived opt_state_dict and creates an opt_state_dict suitable as + input for load_state_dict for restoring optimizer state. 
+ + Reconstructing state_subkeys_and_tags using the example key + prefix, "__opt_state_0_0_", certain to be present. + + Args: + derived_opt_state_dict: Optimizer state dictionary + + Returns: + dict: Optimizer state dictionary + """ + state_subkeys_and_tags = [] + for key in derived_opt_state_dict: + if key.startswith('__opt_state_0_0_'): + stripped_key = key[16:] + if stripped_key.startswith('istensor_'): + this_tag = 'istensor' + subkey = stripped_key[9:] + else: + this_tag = '' + subkey = stripped_key[1:] + state_subkeys_and_tags.append((subkey, this_tag)) + + opt_state_dict = {'param_groups': [], 'state': {}} + nb_params_per_group = list( + derived_opt_state_dict.pop('__opt_group_lengths').astype(np.int32) + ) + + # Construct the expanded dict. + for group_idx, nb_params in enumerate(nb_params_per_group): + these_group_ids = [f'{group_idx}_{idx}' for idx in range(nb_params)] + opt_state_dict['param_groups'].append({'params': these_group_ids}) + for this_id in these_group_ids: + opt_state_dict['state'][this_id] = {} + for subkey, tag in state_subkeys_and_tags: + flat_key = f'__opt_state_{this_id}_{tag}_{subkey}' + if tag == 'istensor': + new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key)) + else: + # Here (for currrently supported optimizers) the subkey + # should be 'step' and the length of array should be one. + assert subkey == 'step' + assert len(derived_opt_state_dict[flat_key]) == 1 + new_v = int(derived_opt_state_dict.pop(flat_key)) + opt_state_dict['state'][this_id][subkey] = new_v + + # sanity check that we did not miss any optimizer state (after removing __opt_state_needed) + derived_opt_state_dict.pop('__opt_state_needed') + if len(derived_opt_state_dict) != 0: + raise ValueError(f"Opt state should have been exausted, but we have left: {derived_opt_state_dict}") + + return opt_state_dict + + +def initialize_tensorkeys_for_functions_util(runner_class, with_opt_vars=False): + """Set the required tensors for all publicly accessible task methods. 
+ + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function. + + Args: + None + + Returns: + None + """ + # TODO there should be a way to programmatically iterate through + # all of the methods in the class and declare the tensors. + # For now this is done manually + + output_model_dict = runner_class.get_tensor_dict(with_opt_vars=with_opt_vars) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + runner_class.logger, output_model_dict, + **runner_class.tensor_dict_split_fn_kwargs + ) + if not with_opt_vars: + global_model_dict_val = global_model_dict + local_model_dict_val = local_model_dict + else: + output_model_dict = runner_class.get_tensor_dict(with_opt_vars=False) + global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( + runner_class.logger, + output_model_dict, + **runner_class.tensor_dict_split_fn_kwargs + ) + + runner_class.required_tensorkeys_for_function['train_batches'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + for tensor_name in global_model_dict] + runner_class.required_tensorkeys_for_function['train_batches'] += [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + for tensor_name in local_model_dict] + + runner_class.required_tensorkeys_for_function['train'] = [ + TensorKey( + tensor_name, 'GLOBAL', 0, False, ('model',) + ) for tensor_name in global_model_dict + ] + runner_class.required_tensorkeys_for_function['train'] += [ + TensorKey( + tensor_name, 'LOCAL', 0, False, ('model',) + ) for tensor_name in local_model_dict + ] + + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + runner_class.required_tensorkeys_for_function['validate'] = {} + # TODO This is not stateless. 
The optimizer will not be + runner_class.required_tensorkeys_for_function['validate']['apply=local'] = [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) + for tensor_name in { + **global_model_dict_val, + **local_model_dict_val + }] + runner_class.required_tensorkeys_for_function['validate']['apply=global'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) + for tensor_name in global_model_dict_val + ] + runner_class.required_tensorkeys_for_function['validate']['apply=global'] += [ + TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) + for tensor_name in local_model_dict_val + ] + + +def to_cpu_numpy(state): + """Send data to CPU as Numpy array. + + Args: + state + + """ + # deep copy so as to decouple from active model + state = deepcopy(state) + + for k, v in state.items(): + # When restoring, we currently assume all values are tensors. + if not pt.is_tensor(v): + raise ValueError('We do not currently support non-tensors ' + 'coming from model.state_dict()') + # get as a numpy array, making sure is on cpu + state[k] = v.cpu().numpy() + return state + + +class DummyDataLoader(): + def __init__(self, feature_shape, training_data_size, valid_data_size): + self.feature_shape = feature_shape + self.training_data_size = training_data_size + self.valid_data_size = valid_data_size + + def get_feature_shape(self): + return self.feature_shape + + def get_training_data_size(self): + return self.training_data_size + + def get_valid_data_size(self): + return self.valid_data_size From 76e11b4d6994f1dd003a2d827e5c5e2df4e1a409 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 14 Jun 2024 01:09:14 +0200 Subject: [PATCH 079/242] remove buggy unused imports --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index ecc3869eb..6857378b8 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ 
b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -22,9 +22,6 @@ from openfl.utilities import TensorKey from openfl.utilities.split import split_tensor_dict_for_holdouts -from openfl.federated.task.runner_pt_utils import rebuild_model_util, derive_opt_state_dict, expand_derived_opt_state_dict -from openfl.federated.task.runner_pt_utils import initialize_tensorkeys_for_functions_util, to_cpu_numpy -from openfl.federated.task.nnunet_v1 import train_nnunet from .runner_pt_chkpt import PyTorchCheckpointTaskRunner from .nnunet_v1 import train_nnunet From 21bbd1d2ee846338dfb0175060e9560b1b472384 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 14 Jun 2024 12:10:22 +0200 Subject: [PATCH 080/242] modify build script --- examples/fl/fl/build.sh | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh index 67cda94a7..9e0ea346f 100644 --- a/examples/fl/fl/build.sh +++ b/examples/fl/fl/build.sh @@ -1,7 +1,16 @@ -git clone https://github.com/securefederatedai/openfl.git -cd openfl -git checkout e6f3f5fd4462307b2c9431184190167aa43d962f -docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . -cd .. -rm -rf openfl +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/securefederatedai/openfl.git + cd openfl + git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. 
+ rm -rf openfl +fi mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always From 046ad750acaaaab5932d6b6fd0725ae2067f96b5 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 2 Jul 2024 17:25:32 +0000 Subject: [PATCH 081/242] sync fl examples --- examples/fl/fl/project/aggregator.py | 7 +++---- examples/fl_post/fl/build.sh | 21 +++++++++++++++------ examples/fl_post/fl/project/utils.py | 9 +-------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/examples/fl/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py index 4d190e5d1..c0bbeafa1 100644 --- a/examples/fl/fl/project/aggregator.py +++ b/examples/fl/fl/project/aggregator.py @@ -6,7 +6,7 @@ prepare_cols_list, prepare_init_weights, create_workspace, - # get_weights_path, + get_weights_path, ) import os @@ -56,6 +56,5 @@ def start_aggregator( # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) - # for now create an arbitrary report - with open(report_path, "w") as f: - f.write("agg_accuracy: 1.0") + with open(report_path, 'w') as f: + f.write("IsDone: 1") diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 67cda94a7..9e0ea346f 100644 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -1,7 +1,16 @@ -git clone https://github.com/securefederatedai/openfl.git -cd openfl -git checkout e6f3f5fd4462307b2c9431184190167aa43d962f -docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . -cd .. -rm -rf openfl +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/securefederatedai/openfl.git + cd openfl + git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. 
+ rm -rf openfl +fi mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/project/utils.py b/examples/fl_post/fl/project/utils.py index e6653ae76..d92add606 100644 --- a/examples/fl_post/fl/project/utils.py +++ b/examples/fl_post/fl/project/utils.py @@ -22,8 +22,6 @@ def get_aggregator_fqdn(fl_workspace): def get_collaborator_cn(): - # TODO: check if there is a way this can cause a collision/race condition - # TODO: from inside the file return os.environ["MEDPERF_PARTICIPANT_LABEL"] @@ -39,7 +37,6 @@ def get_weights_path(fl_workspace): def prepare_plan(plan_path, fl_workspace): target_plan_folder = os.path.join(fl_workspace, "plan") - # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "plan.yaml") @@ -59,7 +56,6 @@ def prepare_cols_list(collaborators_file, fl_workspace): cols_dict = list(cols_dict.keys()) target_plan_folder = os.path.join(fl_workspace, "plan") - # TODO: permissions os.makedirs(target_plan_folder, exist_ok=True) target_plan_file = os.path.join(target_plan_folder, "cols.yaml") with open(target_plan_file, "w") as f: @@ -79,7 +75,6 @@ def prepare_init_weights(input_weights, fl_workspace): target_weights_subpath = get_weights_path(fl_workspace)["init"] target_weights_path = os.path.join(fl_workspace, target_weights_subpath) target_weights_folder = os.path.dirname(target_weights_path) - # TODO: permissions os.makedirs(target_weights_folder, exist_ok=True) os.symlink(file, target_weights_path) @@ -105,7 +100,6 @@ def prepare_node_cert( cert_file = os.path.join(node_cert_folder, cert_file) target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) - # TODO: permissions os.makedirs(target_cert_folder, exist_ok=True) target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") @@ -125,8 +119,7 @@ def prepare_ca_cert(ca_cert_folder, 
fl_workspace): file = os.path.join(ca_cert_folder, file) target_ca_cert_folder = os.path.join(fl_workspace, "cert") - # TODO: permissions os.makedirs(target_ca_cert_folder, exist_ok=True) target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") - os.symlink(file, target_ca_cert_file) \ No newline at end of file + os.symlink(file, target_ca_cert_file) From 685c64bd55568d26b25f678dc12643e366799e9a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 10 Jul 2024 17:32:43 +0200 Subject: [PATCH 082/242] draft update for the shape mismatch fix --- examples/fl_post/fl/project/hooks.py | 28 +- .../fl_post/fl/project/nnunet_data_setup.py | 292 +++++++++--------- .../fl_post/fl/project/nnunet_model_setup.py | 151 ++++----- examples/fl_post/fl/project/nnunet_setup.py | 119 ++++--- examples/fl_post/fl/project/src/nnunet_v1.py | 13 +- .../fl_post/fl/project/src/runner_nnunetv1.py | 1 + 6 files changed, 291 insertions(+), 313 deletions(-) diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py index dca2c0231..50252f866 100644 --- a/examples/fl_post/fl/project/hooks.py +++ b/examples/fl_post/fl/project/hooks.py @@ -49,18 +49,22 @@ def collaborator_pre_training_hook( os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) - nnunet_setup.main([workspace_folder], - 537, # FIXME: does this need to be set in any particular way? 
- f'{init_nnunet_directory}/model_initial_checkpoint.model', - f'{init_nnunet_directory}/model_initial_checkpoint.model.pkl', - 'FLPost', - .8, - 'by_subject_time_pair', - '3d_fullres', - 'nnUNetTrainerV2', - '0', - plans_identifier=None, - num_institutions=1, + # this function returns metadata (model weights and config file) to be distributed out of band + # evan should use this without stuff to overwrite/sync so that it produces the correct metdata + # when evan runs, init_model_path, init_model_info_path should be None + # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) + + nnunet_setup.main(postopp_pardir=workspace_folder, + three_digit_task_num=537, # FIXME: does this need to be set in any particular way? + init_model_path=f'{init_nnunet_directory}/model_initial_checkpoint.model', + init_model_info_path=f'{init_nnunet_directory}/model_initial_checkpoint.model.pkl', + task_name='FLPost', + percent_train=.8, + split_logic='by_subject_time_pair', + network='3d_fullres', + network_trainer='nnUNetTrainerV2', + fold='0', + plans_path="PATHHHH", # TODO: point this to a mounted file. 
IT IS NOT AN OPENFL PLAN cuda_device='0', verbose=False) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index c7647f448..07178b4a8 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -1,10 +1,4 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -""" -Contributors: Brandon Edwards, Micah Sheller - -""" import os import subprocess import pickle as pkl @@ -15,7 +9,7 @@ from nnunet.dataset_conversion.utils import generate_dataset_json -from nnunet_model_setup import trim_data_and_setup_nnunet_models +from fl_model_setup import trim_data_and_setup_model num_to_modality = {'_0000': '_brain_t1n.nii.gz', @@ -24,37 +18,24 @@ '_0003': '_brain_t2f.nii.gz'} -def get_subdirs(parent_directory): - subjects = os.listdir(parent_directory) - # print("before filter:", subjects) - subjects = [p for p in subjects if os.path.isdir(os.path.join(parent_directory, p)) and not p.startswith(".")] - # print("after filter:", subjects) - return subjects - - def subject_time_to_mask_path(pardir, subject, timestamp): mask_fname = f'{subject}_{timestamp}_tumorMask_model_0.nii.gz' return os.path.join(pardir, 'labels', '.tumor_segmentation_backup', subject, timestamp,'TumorMasksForQC', mask_fname) -def create_task_folders(first_three_digit_task_num, num_institutions, task_name): - """ - Creates task folders for all simulated instiutions in the federation - """ - nnunet_dst_pardirs = [] - nnunet_images_train_pardirs = [] - nnunet_labels_train_pardirs = [] +def create_task_folders(task_num, task_name, overwrite_nnunet_datadirs): + task = f'Task{str(task_num)}_{task_name}' - task_nums = range(first_three_digit_task_num, first_three_digit_task_num + num_institutions) - tasks = [f'Task{str(num)}_{task_name}' for num in task_nums] - for task in tasks: + # The NNUnet data path is obtained from an environmental variable + nnunet_dst_pardir 
= os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data', f'{task}') + + nnunet_images_train_pardir = os.path.join(nnunet_dst_pardir, 'imagesTr') + nnunet_labels_train_pardir = os.path.join(nnunet_dst_pardir, 'labelsTr') - # The NNUnet data path is obtained from an environmental variable - nnunet_dst_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data', f'{task}') - - nnunet_images_train_pardir = os.path.join(nnunet_dst_pardir, 'imagesTr') - nnunet_labels_train_pardir = os.path.join(nnunet_dst_pardir, 'labelsTr') + task_cropped_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_cropped_data', f'{task}') + task_preprocessed_pardir = os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_preprocessed', f'{task}') + if not overwrite_nnunet_datadirs: if os.path.exists(nnunet_images_train_pardir) and os.path.exists(nnunet_labels_train_pardir): raise ValueError(f"Train images pardirs: {nnunet_images_train_pardir} and {nnunet_labels_train_pardir} both already exist. 
Please move them both and rerun to prevent overwriting.") elif os.path.exists(nnunet_images_train_pardir): @@ -62,19 +43,32 @@ def create_task_folders(first_three_digit_task_num, num_institutions, task_name) elif os.path.exists(nnunet_labels_train_pardir): raise ValueError(f"Train labels pardir: {nnunet_labels_train_pardir} already exists, please move and run again to prevent overwriting.") - os.makedirs(nnunet_images_train_pardir, exist_ok=False) - os.makedirs(nnunet_labels_train_pardir, exist_ok=False) - - nnunet_dst_pardirs.append(nnunet_dst_pardir) - nnunet_images_train_pardirs.append(nnunet_images_train_pardir) - nnunet_labels_train_pardirs.append(nnunet_labels_train_pardir) + if os.path.exists(task_cropped_pardir): + raise ValueError(f"Cropped data pardir: {task_cropped_pardir} already exists, please move and run again to prevent overwriting.") + if os.path.exists(task_preprocessed_pardir): + raise ValueError(f"Preprocessed data pardir: {task_preprocessed_pardir} already exists, please move and run again to prevent overwriting.") + else: + if os.path.exists(task_cropped_pardir): + shutil.rmtree(task_cropped_pardir) + if os.path.exists(task_preprocessed_pardir): + shutil.rmtree(task_preprocessed_pardir) + if os.path.exists(nnunet_images_train_pardir): + shutil.rmtree(nnunet_images_train_pardir) + if os.path.exists(nnunet_labels_train_pardir): + shutil.rmtree(nnunet_labels_train_pardir) + + + os.makedirs(nnunet_images_train_pardir, exist_ok=False) + os.makedirs(nnunet_labels_train_pardir, exist_ok=False) - return task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs + return task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir -def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_labels_dirpath, nnunet_images_train_pardir, nnunet_labels_train_pardir, timestamp_selection): +def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_labels_dirpath, 
nnunet_images_train_pardir, nnunet_labels_train_pardir, timestamp_selection, verbose=False): + if verbose: + print(f"\n#######\nsymlinking subject: {postopp_subject_dir}\n########\nPostopp_data_dirpath: {postopp_data_dirpath}\n\n\n\n") postopp_subject_dirpath = os.path.join(postopp_data_dirpath, postopp_subject_dir) - all_timestamps = get_subdirs(postopp_subject_dirpath) + all_timestamps = sorted(list(os.listdir(postopp_subject_dirpath))) if timestamp_selection == 'latest': timestamps = all_timestamps[-1:] elif timestamp_selection == 'earliest': @@ -109,7 +103,7 @@ def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_label def doublecheck_postopp_pardir(postopp_pardir, verbose=False): if verbose: print(f"Checking postopp_pardir: {postopp_pardir}") - postopp_subdirs = get_subdirs(postopp_pardir) + postopp_subdirs = list(os.listdir(postopp_pardir)) if 'data' not in postopp_subdirs: raise ValueError(f"'data' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") if 'labels' not in postopp_subdirs: @@ -186,9 +180,10 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): return train_subject_to_timestamps, val_subject_to_timestamps -def write_splits_file(nnunet_dst_pardir, subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False): +def write_splits_file(subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False): # double check we are in the right folder to modify the splits file splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname) + POSTOPP_splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', 'POSTOPP_BACKUP_' + splits_fname) # now split if split_logic == 'by_subject': @@ -211,24 +206,31 @@ def write_splits_file(nnunet_dst_pardir, subject_to_timestamps, percent_train, s # Now write the splits file (note None is put into the folds that we 
don't use as a safety measure so that no unintended folds are used) new_folds = [None, None, None, None, None] new_folds[int(fold)] = OrderedDict({'train': np.array(train_subjects_list), 'val': np.array(val_subjects_list)}) + with open(splits_fpath, 'wb') as f: pkl.dump(new_folds, f) + # Making an extra copy to test that things are not overwriten later + with open(POSTOPP_splits_fpath, 'wb') as f: + pkl.dump(new_folds, f) + -def setup_nnunet_data(postopp_pardirs, - first_three_digit_task_num, +def setup_fl_data(postopp_pardir, + three_digit_task_num, task_name, percent_train, split_logic, fold, timestamp_selection, - num_institutions, network, network_trainer, - plans_identifier, + local_plans_identifier, + shared_plans_identifier, init_model_path, init_model_info_path, - cuda_device, + cuda_device, + overwrite_nnunet_datadirs, + plans_path=None, verbose=False): """ Generates symlinks to be used for NNUnet training, assuming we already have a @@ -291,127 +293,125 @@ def setup_nnunet_data(postopp_pardirs, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - first_three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. + three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. task_name(str) : Any string task name. 
+ percent_train(float) : what percent of data is put into the training data split (rest to val) + split_logic(str) : Determines how train/val split is performed timestamp_selection(str) : Indicates how to determine the timestamp to pick for each subject ID at the source: 'latest', 'earliest', and 'all' are the only ones supported so far + network(str) : Which network is being used for NNUnet + network_trainer(str) : Which network trainer class is being used for NNUnet + local_plans_identifier(str) : Used in the plans file name for a collaborator that will be performing local training to produce an initial model + shared_plans_identifier(str) : Used in the plans file name for creation and dissemination of the shared plan to be used in the federation + init_model_path(str) : Path to the initial model + init_model_info_path(str) : Path to the initial model info (pkl) file + cuda_device(str) : Device to perform training ('cpu' or 'cuda') + overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. + plans_path(str) : Path to the training plans (pkl) percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since all timestamps associated with the same subject need to land exclusively in either train or val. split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair'. fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' - num_institutions(int) : Number of simulated institutions to shard the data into. + timestamp_selection(str) : Determines which timestamps are used for each subject. Can be 'earliest', 'latest', or 'all' verbose(bool) : Debugging output if True. 
Returns: task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs """ - task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs = \ - create_task_folders(first_three_digit_task_num=first_three_digit_task_num, - num_institutions=num_institutions, - task_name=task_name) - - if len(postopp_pardirs) == 1: - postopp_pardir = postopp_pardirs[0] - doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) - postopp_data_dirpaths = num_institutions * [os.path.join(postopp_pardir, 'data')] - postopp_labels_dirpaths = num_institutions * [os.path.join(postopp_pardir, 'labels')] - - all_subjects = get_subdirs(postopp_data_dirpaths[0]) - subject_shards = [all_subjects[start::num_institutions] for start in range(num_institutions)] - elif len(postopp_pardirs) != num_institutions: - raise ValueError(f"The length of postopp_pardirs must be equal to the number of insitutions needed for the federation, or can be length one and an automated split is peroformed.") - else: - subject_shards = [] - postopp_data_dirpaths = [] - postopp_labels_dirpaths = [] - for postopp_pardir in postopp_pardirs: - doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) - postopp_data_dirpath = os.path.join(postopp_pardir, 'data') - postopp_labels_dirpath = os.path.join(postopp_pardir, 'labels') - postopp_data_dirpaths.append(postopp_data_dirpath) - postopp_labels_dirpaths.append(postopp_labels_dirpath) - subject_shards.append(get_subdirs(postopp_labels_dirpath)) + task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir = \ + create_task_folders(task_num=three_digit_task_num, task_name=task_name, overwrite_nnunet_datadirs=overwrite_nnunet_datadirs) + + doublecheck_postopp_pardir(postopp_pardir, verbose=verbose) + postopp_data_dirpath = os.path.join(postopp_pardir, 'data') + postopp_labels_dirpath = os.path.join(postopp_pardir, 'labels') + + all_subjects = list(os.listdir(postopp_data_dirpath)) + # 
Track the subjects and timestamps for each shard - shard_subject_to_timestamps = [] - - for shard_idx, (postopp_subject_dirs, task_num, task, nnunet_dst_pardir, nnunet_images_train_pardir, nnunet_labels_train_pardir, postopp_data_dirpath, postopp_labels_dirpath) in \ - enumerate(zip(subject_shards, task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs, postopp_data_dirpaths, postopp_labels_dirpaths)): - print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA FOR COLLABORATOR {shard_idx} #########\n") - subject_to_timestamps = {} - for postopp_subject_dir in postopp_subject_dirs: - subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, - postopp_data_dirpath=postopp_data_dirpath, - postopp_labels_dirpath=postopp_labels_dirpath, - nnunet_images_train_pardir=nnunet_images_train_pardir, - nnunet_labels_train_pardir=nnunet_labels_train_pardir, - timestamp_selection=timestamp_selection) - shard_subject_to_timestamps.append(subject_to_timestamps) - - # Generate json file for the dataset - print(f"\n######### GENERATING DATA JSON FILE FOR COLLABORATOR {shard_idx} #########\n") - json_path = os.path.join(nnunet_dst_pardir, 'dataset.json') - labels = {0: 'Background', 1: 'Necrosis', 2: 'Edema', 3: 'Enhancing Tumor', 4: 'Cavity'} - # labels = {0: 'Background', 1: 'Necrosis', 2: 'Edema'} - # print(f"{nnunet_images_train_pardir}") - # print(list(os.listdir(nnunet_images_train_pardir))) - + subject_to_timestamps = {} - # from typing import List, Union - # def subfiles(folder: str, join: bool = True, prefix: Union[List[str], str] = None, - # suffix: Union[List[str], str] = None, sort: bool = True) -> List[str]: - # if join: - # l = os.path.join - # else: - # l = lambda x, y: y - - # if prefix is not None and isinstance(prefix, str): - # prefix = [prefix] - # if suffix is not None and isinstance(suffix, str): - # suffix = [suffix] - # print([ i for i in os.listdir(folder)]) - # print([ i for 
i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) ]) - # res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) - # and (prefix is None or any([i.startswith(j) for j in prefix])) - # and (suffix is None or any([i.endswith(j) for j in suffix]))] - - # if sort: - # res.sort() - # return res + print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") + for postopp_subject_dir in all_subjects: + subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, + postopp_data_dirpath=postopp_data_dirpath, + postopp_labels_dirpath=postopp_labels_dirpath, + nnunet_images_train_pardir=nnunet_images_train_pardir, + nnunet_labels_train_pardir=nnunet_labels_train_pardir, + timestamp_selection=timestamp_selection, + verbose=verbose) - - # uniques = np.unique([i[:-12] for i in subfiles(nnunet_images_train_pardir, suffix='.nii.gz', join=False)]) - # print("UNIQUES::::\n",uniques) - generate_dataset_json(output_file=json_path, - imagesTr_dir=nnunet_images_train_pardir, - imagesTs_dir=None, - modalities=tuple(num_to_modality.keys()), - labels=labels, - dataset_name='RANO Postopp') - - # Now call the os process to preprocess the data - print(f"\n######### OS CALL TO PREPROCESS DATA FOR COLLABORATOR {shard_idx} #########\n") - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{task_num}", "--verify_dataset_integrity"]) - - # trim 2d data if not working with 2d model, and distribute common model architecture across simulated collaborators - trim_data_and_setup_nnunet_models(tasks=tasks, - network=network, - network_trainer=network_trainer, - plans_identifier=plans_identifier, - fold=fold, - init_model_path=init_model_path, - init_model_info_path=init_model_info_path, - cuda_device=cuda_device) - + # Generate json file for the dataset + print(f"\n######### GENERATING DATA JSON FILE #########\n") + json_path = os.path.join(nnunet_dst_pardir, 'dataset.json') + labels = {0: 
'Background', 1: 'Necrosis', 2: 'Edema', 3: 'Enhancing Tumor', 4: 'Cavity'} + generate_dataset_json(output_file=json_path, imagesTr_dir=nnunet_images_train_pardir, imagesTs_dir=None, modalities=tuple(num_to_modality.keys()), + labels=labels, dataset_name='RANO Postopp') - for task, subject_to_timestamps in zip(tasks, shard_subject_to_timestamps): - # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val - write_splits_file(nnunet_dst_pardir=nnunet_dst_pardir, - subject_to_timestamps=subject_to_timestamps, + # Now call the os process to preprocess the data + print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") + if plans_path: + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + plans_identifier_for_model_writing = shared_plans_identifier + else: + # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + plans_identifier_for_model_writing = local_plans_identifier + + # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val + write_splits_file(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_logic=split_logic, fold=fold, task=task, verbose=verbose) + + # trim 2d data if not working with 2d model, then train an initial model if needed (initial_model_path is None) or write in provided model otherwise + col_paths = {} + col_paths['initial_model_path'], \ + col_paths['final_model_path'], \ + col_paths['initial_model_info_path'], \ 
+ col_paths['final_model_info_path'], \ + col_paths['plans_path'] = trim_data_and_setup_model(task=task, + network=network, + network_trainer=network_trainer, + plans_identifier=plans_identifier_for_model_writing, + fold=fold, + init_model_path=init_model_path, + init_model_info_path=init_model_info_path, + plans_path=plans_path, + cuda_device=cuda_device) + + if not plans_path: + # In this case we have created an initial model with this data, so running preprocesssing again in order to create a 'pretrained' plan similar to what other collaborators will create with our initial plan + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{col_paths['plans_path']}", "-overwrite_plans_identifier", "POSTOPP", "--verify_dataset_integrity", "-no_pp"]) + # Now coying the collaborator paths above to a new location that uses the pretrained planner that will be shared across federation + new_col_paths = {} + new_col_paths['initial_model_path'], \ + new_col_paths['final_model_path'], \ + new_col_paths['initial_model_info_path'], \ + new_col_paths['final_model_info_path'], \ + new_col_paths['plans_path'] = trim_data_and_setup_model(task=task, + network=network, + network_trainer=network_trainer, + plans_identifier=shared_plans_identifier, + fold=fold, + init_model_path=col_paths['initial_model_path'], + init_model_info_path=col_paths['initial_model_info_path'], + plans_path=col_paths['plans_path'], + cuda_device=cuda_device) + + col_paths = new_col_paths + + print(f"\n### ### ### ### ### ### ###\n") + print(f"A MODEL HAS TRAINED. 
HERE ARE PATHS WHERE FILES CAN BE OBTAINED:\n") + print(f"initial_model_path: {col_paths['initial_model_path']}") + print(f"initial_model_info_path: {col_paths['initial_model_info_path']}") + print(f"final_model_path: {col_paths['final_model_path']}") + print(f"final_model_info_path: {col_paths['final_model_info_path']}") + print(f"plans_path: {col_paths['plans_path']}") + print(f"\n### ### ### ### ### ### ###\n") + return col_paths \ No newline at end of file diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 30d550bdb..4b684a69e 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -1,19 +1,11 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -""" -Contributors: Brandon Edwards, Micah Sheller - -""" - import os import pickle as pkl import shutil -# from nnunet_v1 import train_nnunet +from nnunet_v1 import train_nnunet +from nnunet.paths import default_plans_identifier - -def train_on_task(task, network, network_trainer, fold, cuda_device, continue_training=False, current_epoch=0): +def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): os.environ['CUDA_VISIBLE_DEVICES']=cuda_device print(f"###########\nStarting training for task: {task}\n") train_nnunet(epochs=1, @@ -22,32 +14,34 @@ def train_on_task(task, network, network_trainer, fold, cuda_device, continue_tr task=task, network_trainer = network_trainer, fold=fold, - continue_training=continue_training) + continue_training=continue_training, + p=plans_identifier) -def model_folder(network, task, network_trainer, plans_identifier, fold, results_folder=os.environ['RESULTS_FOLDER']): +def get_model_folder(network, task, network_trainer, plans_identifier, fold, results_folder=os.environ['RESULTS_FOLDER']): return os.path.join(results_folder, 'nnUNet',network, task, 
network_trainer + '__' + plans_identifier, f'fold_{fold}') -def model_paths_from_folder(model_folder): +def get_col_model_paths(model_folder): return {'initial_model_path': os.path.join(model_folder, 'model_initial_checkpoint.model'), 'final_model_path': os.path.join(model_folder, 'model_final_checkpoint.model'), 'initial_model_info_path': os.path.join(model_folder, 'model_initial_checkpoint.model.pkl'), 'final_model_info_path': os.path.join(model_folder, 'model_final_checkpoint.model.pkl')} -def plan_path(network, task, plans_identifier): +def get_col_plans_path(network, task, plans_identifier): + # returning a dictionary in ordre to incorporate it more easily into another paths dict preprocessed_path = os.environ['nnUNet_preprocessed'] - plan_dirpath = os.path.join(preprocessed_path, task) - plan_path_2d = os.path.join(plan_dirpath, plans_identifier + "_plans_2D.pkl") - plan_path_3d = os.path.join(plan_dirpath, plans_identifier + "_plans_3D.pkl") + plans_write_dirpath = os.path.join(preprocessed_path, task) + plans_write_path_2d = os.path.join(plans_write_dirpath, plans_identifier + "_plans_2D.pkl") + plans_write_path_3d = os.path.join(plans_write_dirpath, plans_identifier + "_plans_3D.pkl") if network =='2d': - plan_path = plan_path_2d + plans_write_path = plans_write_path_2d else: - plan_path = plan_path_3d + plans_write_path = plans_write_path_3d - return plan_path + return {'plans_path': plans_write_path} def delete_2d_data(network, task, plans_identifier): if network == '2d': @@ -55,25 +49,26 @@ def delete_2d_data(network, task, plans_identifier): else: preprocessed_path = os.environ['nnUNet_preprocessed'] plan_dirpath = os.path.join(preprocessed_path, task) - plan_path_2d = os.path.join(plan_dirpath, plans_identifier + "_plans_2D.pkl") + plan_path_2d = os.path.join(plan_dirpath, "nnUNetPlansv2.1_plans_2D.pkl") - if os.path.exists(plan_path_2d): + if os.path.exists(plan_dirpath): # load 2d plan to help construct 2D data directory with open(plan_path_2d, 
'rb') as _file: plan_2d = pkl.load(_file) data_dir_2d = os.path.join(plan_dirpath, plan_2d['data_identifier'] + '_stage' + str(list(plan_2d['plans_per_stage'].keys())[-1])) - print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") - shutil.rmtree(data_dir_2d) - + if os.path.exists(data_dir_2d): + print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") + shutil.rmtree(data_dir_2d) +""" def normalize_architecture(reference_plan_path, target_plan_path): - """ - Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path + + # Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path - NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more + # NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more likely leed to an exception. - """ + assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', @@ -109,73 +104,49 @@ def write_pickled_obj(obj, path): # write back to target plan write_pickled_obj(obj=target_plan, path=target_plan_path) +""" +def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'): + """ + Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one + """ -def trim_data_and_setup_nnunet_models(tasks, network, network_trainer, plans_identifier, fold, init_model_path, init_model_info_path, cuda_device='0'): - - col_0_task = tasks[0] - # trim collaborator 0 data if appropriate + # Remove 2D data and 
2D data info if appropriate if network != '2d': - delete_2d_data(network=network, task=col_0_task, plans_identifier=plans_identifier) - # get the architecture info from the first collaborator 0 data setup results, and create its model folder (writing the initial model info into it) - col_0_plan_path = plan_path(network=network, task=col_0_task, plans_identifier=plans_identifier) + delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) + + # get or create architecture info - col_0_model_folder = model_folder(network=network, - task=col_0_task, + model_folder = get_model_folder(network=network, + task=task, network_trainer=network_trainer, plans_identifier=plans_identifier, fold=fold) - os.makedirs(col_0_model_folder, exist_ok=False) + if not os.path.exists(model_folder): + os.makedirs(model_folder, exist_ok=False) + + col_paths = get_col_model_paths(model_folder=get_model_folder(network=network, + task=task, + network_trainer=network_trainer, + plans_identifier=plans_identifier, + fold=fold)) + col_paths.update(get_col_plans_path(network=network, task=task, plans_identifier=plans_identifier)) - col_0_model_files_dict = model_paths_from_folder(model_folder=model_folder(network=network, - task=col_0_task, - network_trainer=network_trainer, - plans_identifier=plans_identifier, - fold=fold)) if not init_model_path: - # train collaborator 0 for a single epoch to get an initial model - train_on_task(task=col_0_task, network=network, network_trainer=network_trainer, fold=fold, cuda_device=cuda_device) - # now copy the final model and info from the initial training run into the initial paths - shutil.copyfile(src=col_0_model_files_dict['final_model_path'],dst=col_0_model_files_dict['initial_model_path']) - shutil.copyfile(src=col_0_model_files_dict['final_model_info_path'],dst=col_0_model_files_dict['initial_model_info_path']) + if plans_path: + raise ValueError(f"If the initial model is not provided then we do not expect the plans_path to be provided 
either (plans file and initial model are sourced the same way).") + # train for a single epoch to get an initial model (this uses the default plans identifier) + train_on_task(task=task, network=network, network_trainer=network_trainer, fold=fold, cuda_device=cuda_device, plans_identifier=default_plans_identifier) + # now copy the trained final model and info into the initial paths + shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) + shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) else: - print(f"\n######### COPYING INITIAL MODEL FILES INTO COLLABORATOR 0 FOLDERS #########\n") - # Copy initial model and model info into col_0_model_folder - shutil.copyfile(src=init_model_path,dst=col_0_model_files_dict['initial_model_path']) - shutil.copyfile(src=init_model_info_path,dst=col_0_model_files_dict['initial_model_info_path']) - # now copy the initial model also into the final paths - shutil.copyfile(src=col_0_model_files_dict['initial_model_path'],dst=col_0_model_files_dict['final_model_path']) - shutil.copyfile(src=col_0_model_files_dict['initial_model_info_path'],dst=col_0_model_files_dict['final_model_info_path']) - - # now create the model folders for collaborators 1 and upward, populate them with the model files from 0, - # and replace their data directory plan files from the col_0 plan - for col_idx_minus_one, task in enumerate(tasks[1:]): - # trim collaborator data if appropriate - if network != '2d': - delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) - - print(f"\n######### COPYING MODEL INFO FROM COLLABORATOR 0 TO COLLABORATOR {col_idx_minus_one + 1} #########\n") - # replace data directory plan file with one from col_0 - target_plan_path = plan_path(network=network, task=task, plans_identifier=plans_identifier) - normalize_architecture(reference_plan_path=col_0_plan_path, target_plan_path=target_plan_path) - - # create model folder for this 
collaborator - this_col_model_folder = model_folder(network=network, - task=task, - network_trainer=network_trainer, - plans_identifier=plans_identifier, - fold=fold) - os.makedirs(this_col_model_folder, exist_ok=False) - - # copy model, and model info files from col_0 to this collaborator's model folder - this_col_model_files_dict = model_paths_from_folder(model_folder=model_folder(network=network, - task=task, - network_trainer=network_trainer, - plans_identifier=plans_identifier, - fold=fold)) - # Copy initial and final model and model info from col_0 into this_col_model_folder - shutil.copyfile(src=col_0_model_files_dict['initial_model_path'],dst=this_col_model_files_dict['initial_model_path']) - shutil.copyfile(src=col_0_model_files_dict['final_model_path'],dst=this_col_model_files_dict['final_model_path']) - shutil.copyfile(src=col_0_model_files_dict['initial_model_info_path'],dst=this_col_model_files_dict['initial_model_info_path']) - shutil.copyfile(src=col_0_model_files_dict['final_model_info_path'],dst=this_col_model_files_dict['final_model_info_path']) - \ No newline at end of file + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n") + shutil.copy(src=plans_path,dst=col_paths['plans_path']) + shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) + shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) + # now copy these files also into the final paths + shutil.copyfile(src=col_paths['initial_model_path'],dst=col_paths['final_model_path']) + shutil.copyfile(src=col_paths['initial_model_info_path'],dst=col_paths['final_model_info_path']) + + return col_paths['initial_model_path'], col_paths['final_model_path'], col_paths['initial_model_info_path'], col_paths['final_model_info_path'], col_paths['plans_path'] \ No newline at end of file diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 
cdc54d9e6..86dd6003f 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -1,34 +1,35 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 +import argparse -""" -Contributors: Brandon Edwards, Micah Sheller +# We will be syncing training across many nodes who independently preprocess data +# In order to do this we will need to sync the training plans (defining the model architecture etc.) +# NNUnet does this by overwriting the plans file which includes a unique alternative plans identifier other than the default one -""" +from nnunet.paths import default_plans_identifier -import argparse +from fl_data_setup import setup_fl_data -from nnunet.paths import default_plans_identifier +local_plans_identifier = default_plans_identifier +shared_plans_identifier = 'nnUNetPlans_pretrained_POSTOPP' -from nnunet_data_setup import setup_nnunet_data -from nnunet_model_setup import trim_data_and_setup_nnunet_models def list_of_strings(arg): return arg.split(',') -def main(postopp_pardirs, - first_three_digit_task_num, - init_model_path, - init_model_info_path, +def main(postopp_pardir, + three_digit_task_num, task_name, - percent_train, - split_logic, - network, - network_trainer, - fold, - plans_identifier=None, + percent_train=0.8, + split_logic='by_subject_time_pair', + network='3d_fullres', + network_trainer='nnUNetTrainerV2', + fold='0', + init_model_path=None, + init_model_info_path=None, + plans_path=None, + local_plans_identifier=local_plans_identifier, + shared_plans_identifier=shared_plans_identifier, + overwrite_nnunet_datadirs=False, timestamp_selection='all', - num_institutions=1, cuda_device='0', verbose=False): """ @@ -40,11 +41,7 @@ def main(postopp_pardirs, should be run using a virtual environment that has nnunet version 1 installed. args: - postopp_pardirs(list of str) : Parent directories for postopp data. 
The length of the list should either be - equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, - the samples within that single directory will be used to create num_insititutions shards. - If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. - Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + postopp_pardir(str) : Parent directory for postopp data, which should contain 'data' and 'labels' subdirectories with structure: ├── data │ ├── AAAC_0 │ │ ├── 2008.03.30 @@ -92,11 +89,13 @@ def main(postopp_pardirs, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - first_three_digit_task_num(str) : Should start with '5'. If nnunet == N, all N task numbers starting with this number will be used. + three_digit_task_num(str) : Should start with '5' and not collide with other NNUnet task nums on your system. init_model_path (str) : path to initial (pretrained) model file [default None] - must be provided if init_model_info_path is. - [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0 DATA.] + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] init_model_info_path(str) : path to initial (pretrained) model info pikle file [default None]- must be provided if init_model_path is. - [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0 DATA.] + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] + plans_path(str) : Path the the NNUnet plan file + [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.] 
task_name(str) : Name of task that is part of the task name percent_train(float) : The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite all with None except the fold indicated below and put in our own split instead determined by a hard coded split logic default) @@ -104,23 +103,18 @@ def main(postopp_pardirs, network(str) : NNUnet network to be used network_trainer(str) : NNUnet network trainer to be used fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' - plans_identifier(str) : Used in the plans file naming. + local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. + shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. + overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one less than number of institutions. task_name(str) : Any string task name. timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far - num_institutions(int) : Number of simulated institutions to shard the data into. verbose(bool) : If True, print debugging information. 
""" - if plans_identifier is None: - plans_identifier = default_plans_identifier - # some argument inspection - task_digit_length = len(str(first_three_digit_task_num)) - if task_digit_length != 3: - raise ValueError(f'The number of digits in {first_three_digit_task_num} should be 3, but it is: {task_digit_length} instead.') - if str(first_three_digit_task_num)[0] != '5': - raise ValueError(f"The three digit task number: {first_three_digit_task_num} should start with 5 to avoid NNUnet repository tasks, but it starts with {first_three_digit_task_num[0]}") + if str(three_digit_task_num)[0] != '5': + raise ValueError(f"The three digit task number: {three_digit_task_num} should start with 5 to avoid NNUnet repository tasks, but it starts with {three_digit_task_num[0]}") if init_model_path or init_model_info_path: if not init_model_path or not init_model_info_path: raise ValueError(f"If either init_model_path or init_model_info_path are provided, they both must be.") @@ -130,48 +124,54 @@ def main(postopp_pardirs, if not init_model_info_path.endswith('.model.pkl'): raise ValueError(f"Initial model info file should end with, 'model.pkl'") - - # task_folder_info is a zipped lists indexed over tasks (collaborators) # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) - setup_nnunet_data(postopp_pardirs=postopp_pardirs, - first_three_digit_task_num=first_three_digit_task_num, + col_paths = setup_fl_data(postopp_pardir=postopp_pardir, + three_digit_task_num=three_digit_task_num, task_name=task_name, percent_train=percent_train, split_logic=split_logic, fold=fold, timestamp_selection=timestamp_selection, - num_institutions=num_institutions, network=network, network_trainer=network_trainer, - plans_identifier=plans_identifier, + local_plans_identifier=local_plans_identifier, + shared_plans_identifier=shared_plans_identifier, init_model_path=init_model_path, - init_model_info_path=init_model_info_path, - cuda_device=cuda_device, 
+ init_model_info_path=init_model_info_path, + plans_path=plans_path, + cuda_device=cuda_device, + overwrite_nnunet_datadirs=overwrite_nnunet_datadirs, verbose=verbose) + + return col_paths if __name__ == '__main__': argparser = argparse.ArgumentParser() argparser.add_argument( - '--postopp_pardirs', - type=list_of_strings, - # nargs='+', - help="Parent directories to postopp data (all should have 'data' and 'labels' subdirectories). Length needs to equal num_institutions or be lengh 1.") + '--postopp_pardir', + type=str, + help="Parent directory to postopp data.") argparser.add_argument( - '--first_three_digit_task_num', + '--three_digit_task_num', type=int, - help="Should start with '5'. If nnunet == N, all N task numbers starting with this number will be used.") + help="Should start with '5'. If fedsim == N, all N task numbers starting with this number will be used.") argparser.add_argument( '--init_model_path', type=str, default=None, - help="Path to initial (pretrained) model file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0's DATA.].") + help="Path to initial (pretrained) model file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") argparser.add_argument( '--init_model_info_path', type=str, default=None, - help="Path to initial (pretrained) model info file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION. 
OTHERWISE ARCHITECTURE IS CHOSEN USING COLLABORATOR 0's DATA.].") + help="Path to initial (pretrained) model info file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") + argparser.add_argument( + '--plans_path', + type=str, + default=None, + help="Path to the training plan file [ONLY USE IF YOU KNOW THE MODEL ARCHITECTURE MAKES SENSE FOR THE FEDERATION.].") argparser.add_argument( '--task_name', type=str, @@ -206,16 +206,15 @@ def main(postopp_pardirs, type=str, default='all', help="Indicates how to determine the timestamp to pick for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far.") - argparser.add_argument( - '--num_institutions', - type=int, - default=1, - help="Number of symulated insitutions to shard the data into.") argparser.add_argument( '--cuda_device', type=str, default='0', - help="Used for the setting of os.environ['CUDA_VISIBLE_DEVICES']") + help="Used for the setting of os.environ['CUDA_VISIBLE_DEVICES']") + argparser.add_argument( + '--overwrite_nnunet_datadirs', + action='store_true', + help="Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one less than number of institutions.") argparser.add_argument( '--verbose', action='store_true', diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 701d02226..78869d4c1 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -20,8 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# EDITED for OpenFL test integration by Brandon Edwards and Micah Sheller - import os import numpy as np @@ -29,7 +27,6 @@ import random from batchgenerators.utilities.file_and_folder_operations import * from nnunet.run.default_configuration import get_default_configuration -from nnunet.paths import default_plans_identifier from nnunet.run.load_pretrained_weights import load_pretrained_weights from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer @@ -41,6 +38,13 @@ ) from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name + +# We will be syncing training across many nodes who independently preprocess data +# In order to do this we will need to sync the training plans (defining the model architecture etc.) +# NNUnet does this by overwriting the plans file which includes a unique alternative plans identifier other than the default one +plans_param = 'nnUNetPlans_pretrained_POSTOPP' +#from nnunet.paths import default_plans_identifier + def seed_everything(seed=1234): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) @@ -59,7 +63,7 @@ def train_nnunet(epochs, continue_training=True, validation_only=False, c=False, - p=default_plans_identifier, + p=plans_param, use_compressed_data=False, deterministic=False, npz=False, @@ -164,7 +168,6 @@ def __init__(self, **kwargs): # force_separate_z = True # else: # raise ValueError("force_separate_z must be None, True or False. 
Given: %s" % force_separate_z) - ( plans_file, output_folder_name, diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 6857378b8..43e45b128 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -6,6 +6,7 @@ """ # TODO: Clean up imports +# TODO: ask Micah if this has to be changed (most probably no) import os import subprocess From e9029898d83abb931fdf7550611a5cbb22599f42 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 10 Jul 2024 17:20:09 +0000 Subject: [PATCH 083/242] rename imports --- examples/fl_post/fl/project/nnunet_data_setup.py | 2 +- examples/fl_post/fl/project/nnunet_setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 07178b4a8..30ff9addd 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -9,7 +9,7 @@ from nnunet.dataset_conversion.utils import generate_dataset_json -from fl_model_setup import trim_data_and_setup_model +from nnunet_model_setup import trim_data_and_setup_model num_to_modality = {'_0000': '_brain_t1n.nii.gz', diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 86dd6003f..0106b8f95 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -6,7 +6,7 @@ from nnunet.paths import default_plans_identifier -from fl_data_setup import setup_fl_data +from nnunet_data_setup import setup_fl_data local_plans_identifier = default_plans_identifier shared_plans_identifier = 'nnUNetPlans_pretrained_POSTOPP' From 41b4fd9b6108ba1ec57efe94f9c0db4b6a6a9284 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Jul 2024 15:49:59 +0000 Subject: [PATCH 084/242] bugfixes, add init_model task --- 
examples/fl_post/fl/mlcube/mlcube.yaml | 8 ++++ examples/fl_post/fl/project/hooks.py | 2 +- examples/fl_post/fl/project/init_model.py | 47 +++++++++++++++++++ examples/fl_post/fl/project/mlcube.py | 18 +++++++ .../fl_post/fl/project/nnunet_model_setup.py | 2 +- examples/fl_post/fl/project/requirements.txt | 1 + .../fl_post/fl/project/src/runner_nnunetv1.py | 8 ++-- 7 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 examples/fl_post/fl/project/init_model.py diff --git a/examples/fl_post/fl/mlcube/mlcube.yaml b/examples/fl_post/fl/mlcube/mlcube.yaml index b13dc4eb8..7fbe19c0a 100644 --- a/examples/fl_post/fl/mlcube/mlcube.yaml +++ b/examples/fl_post/fl/mlcube/mlcube.yaml @@ -46,3 +46,11 @@ tasks: aggregator_config_path: aggregator_config.yaml outputs: plan_path: { type: "file", default: "plan/plan.yaml" } + train_initial_model: + parameters: + inputs: + data_path: data/ + labels_path: labels/ + outputs: + output_logs: logs/ + init_nnunet_directory: init_nnunet_directory/ diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py index 50252f866..a079bd586 100644 --- a/examples/fl_post/fl/project/hooks.py +++ b/examples/fl_post/fl/project/hooks.py @@ -64,7 +64,7 @@ def collaborator_pre_training_hook( network='3d_fullres', network_trainer='nnUNetTrainerV2', fold='0', - plans_path="PATHHHH", # TODO: point this to a mounted file. 
IT IS NOT AN OPENFL PLAN + plans_path=f'{init_nnunet_directory}/nnUNetPlans_pretrained_POSTOPP_plans_3D.pkl', # NOTE: IT IS NOT AN OPENFL PLAN cuda_device='0', verbose=False) diff --git a/examples/fl_post/fl/project/init_model.py b/examples/fl_post/fl/project/init_model.py new file mode 100644 index 000000000..d436e26d3 --- /dev/null +++ b/examples/fl_post/fl/project/init_model.py @@ -0,0 +1,47 @@ +import os +import shutil + + +def train_initial_model( + data_path, + labels_path, + output_logs, + init_nnunet_directory, +): + # runtime env vars should be set as early as possible + tmpfolder = os.path.join(output_logs, ".tmp") + os.makedirs(tmpfolder, exist_ok=True) + os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") + os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") + os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") + import nnunet_setup + + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + + os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) + os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) + + res = nnunet_setup.main( + postopp_pardir=workspace_folder, + three_digit_task_num=537, # FIXME: does this need to be set in any particular way? 
+ init_model_path=None, + init_model_info_path=None, + task_name="FLPost", + percent_train=0.8, + split_logic="by_subject_time_pair", + network="3d_fullres", + network_trainer="nnUNetTrainerV2", + fold="0", + plans_path=None, + cuda_device="0", + verbose=False, + ) + + initial_model_path = res["initial_model_path"] + initial_model_info_path = res["initial_model_info_path"] + plans_path = res["plans_path"] + + shutil.move(initial_model_path, init_nnunet_directory) + shutil.move(initial_model_info_path, init_nnunet_directory) + shutil.move(plans_path, init_nnunet_directory) diff --git a/examples/fl_post/fl/project/mlcube.py b/examples/fl_post/fl/project/mlcube.py index 0fe02af13..ba0140395 100644 --- a/examples/fl_post/fl/project/mlcube.py +++ b/examples/fl_post/fl/project/mlcube.py @@ -12,6 +12,7 @@ collaborator_pre_training_hook, collaborator_post_training_hook, ) +from init_model import train_initial_model app = typer.Typer() @@ -124,5 +125,22 @@ def generate_plan_( generate_plan(training_config_path, aggregator_config_path, plan_path) +@app.command("train_initial_model") +def train_initial_model_( + data_path: str = typer.Option(..., "--data_path"), + labels_path: str = typer.Option(..., "--labels_path"), + output_logs: str = typer.Option(..., "--output_logs"), + init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), +): + _setup(output_logs) + train_initial_model( + data_path=data_path, + labels_path=labels_path, + output_logs=output_logs, + init_nnunet_directory=init_nnunet_directory, + ) + _teardown(output_logs) + + if __name__ == "__main__": app() diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 4b684a69e..4ebd1f9e7 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -2,7 +2,7 @@ import pickle as pkl import shutil -from nnunet_v1 import train_nnunet +from src.nnunet_v1 import train_nnunet from 
nnunet.paths import default_plans_identifier def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt index c3eb2a404..8f03308f9 100644 --- a/examples/fl_post/fl/project/requirements.txt +++ b/examples/fl_post/fl/project/requirements.txt @@ -1,3 +1,4 @@ onnx==1.13.0 typer==0.9.0 git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@main#egg=nnunet +numpy==1.26.4 diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 43e45b128..26f184522 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -27,6 +27,8 @@ from .runner_pt_chkpt import PyTorchCheckpointTaskRunner from .nnunet_v1 import train_nnunet +shared_plans_identifier = 'nnUNetPlans_pretrained_POSTOPP' + class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): """An abstract class for PyTorch model based Tasks, where training, validation etc. 
are processes that pull model state from a PyTorch checkpoint.""" @@ -53,17 +55,17 @@ def __init__(self, super().__init__( checkpoint_path_initial=os.path.join( os.environ['RESULTS_FOLDER'], - f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', 'model_initial_checkpoint.model' ), checkpoint_path_save=os.path.join( os.environ['RESULTS_FOLDER'], - f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', 'model_final_checkpoint.model' ), checkpoint_path_load=os.path.join( os.environ['RESULTS_FOLDER'], - f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0/', + f'nnUNet/3d_fullres/{nnunet_task}/nnUNetTrainerV2__{shared_plans_identifier}/fold_0/', 'model_final_checkpoint.model' ), **kwargs, From 73f145fed8268710fc4fe0fe622b9be9ecf2ba83 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Jul 2024 16:09:47 +0000 Subject: [PATCH 085/242] change tests config scripts --- examples/fl_post/fl/setup_test.sh | 9 -------- examples/fl_post/fl/setup_test_no_docker.sh | 8 ------- examples/fl_post/fl/sync.sh | 6 +++++ examples/fl_post/fl/test.sh | 25 +++++++++++++++------ examples/fl_post/fl/test_init.sh | 1 + 5 files changed, 25 insertions(+), 24 deletions(-) create mode 100644 examples/fl_post/fl/test_init.sh diff --git a/examples/fl_post/fl/setup_test.sh b/examples/fl_post/fl/setup_test.sh index 542dd7164..72a1c55b9 100644 --- a/examples/fl_post/fl/setup_test.sh +++ b/examples/fl_post/fl/setup_test.sh @@ -82,12 +82,3 @@ cd ../.. 
cp -r mlcube_col2/workspace/data mlcube_col3/workspace cp -r mlcube_col2/workspace/labels mlcube_col3/workspace - -# weights download -cd mlcube_agg/workspace/ -mkdir additional_files -cd additional_files -wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz -tar -xf init_weights_miccai.tar.gz -rm init_weights_miccai.tar.gz -cd ../../.. diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index 879e84ced..2efdcb069 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -114,11 +114,3 @@ cd ../.. cp -r mlcube_col2/workspace/data mlcube_col3/workspace cp -r mlcube_col2/workspace/labels mlcube_col3/workspace -# weights download -cd mlcube_agg/workspace/ -mkdir additional_files -cd additional_files -wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz -tar -xf init_weights_miccai.tar.gz -rm init_weights_miccai.tar.gz -cd ../../.. 
diff --git a/examples/fl_post/fl/sync.sh b/examples/fl_post/fl/sync.sh index a5375ce54..53460ec75 100644 --- a/examples/fl_post/fl/sync.sh +++ b/examples/fl_post/fl/sync.sh @@ -4,3 +4,9 @@ cp mlcube/mlcube.yaml mlcube_agg/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col3/mlcube.yaml + +rm -r mlcube_col2/workspace/additional_files +rm -r mlcube_col3/workspace/additional_files + +cp -r mlcube/workspace/additional_files mlcube_col2/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_col3/workspace/additional_files diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 3a154936a..b4b5d42eb 100644 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -8,14 +8,25 @@ cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace # Run nodes AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +COL1="medperf --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" -gnome-terminal -- bash -c "$AGG; bash" -gnome-terminal -- bash -c "$COL1; bash" -gnome-terminal -- bash -c "$COL2; bash" -gnome-terminal -- bash -c "$COL3; bash" +# gnome-terminal -- bash -c "$AGG; bash" +# gnome-terminal -- bash -c "$COL1; bash" +# gnome-terminal -- bash -c "$COL2; bash" +# gnome-terminal -- bash -c "$COL3; bash" 
+rm agg.log col1.log col2.log col3.log +$AGG >>agg.log & +sleep 6 +$COL2 >>col2.log & +sleep 6 +$COL3 >>col3.log & +# sleep 6 +# $COL2 >> col2.log & +# sleep 6 +# $COL3 >> col3.log & +wait # docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 # python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/test_init.sh b/examples/fl_post/fl/test_init.sh new file mode 100644 index 000000000..3dee25f0a --- /dev/null +++ b/examples/fl_post/fl/test_init.sh @@ -0,0 +1 @@ +medperf --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train_initial_model From 30d394f4f5a11b481c954d2036f93683b8f07e15 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 18 Jul 2024 16:37:52 +0000 Subject: [PATCH 086/242] modify test setup script --- examples/fl_post/fl/setup_test_no_docker.sh | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index 2efdcb069..e422a0d8e 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -94,23 +94,5 @@ echo "address: $HOSTNAME_" 
>>mlcube_agg/workspace/aggregator_config.yaml echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols file -echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml - -# data download -cd mlcube_col1/workspace/ -wget https://storage.googleapis.com/medperf-storage/testfl/col1_prepared.tar.gz -tar -xf col1_prepared.tar.gz -rm col1_prepared.tar.gz -cd ../.. - -cd mlcube_col2/workspace/ -wget https://storage.googleapis.com/medperf-storage/testfl/col2_prepared.tar.gz -tar -xf col2_prepared.tar.gz -rm col2_prepared.tar.gz -cd ../.. - -cp -r mlcube_col2/workspace/data mlcube_col3/workspace -cp -r mlcube_col2/workspace/labels mlcube_col3/workspace - From 8405bd2d840262925d6045806a2fbfacdd13d678 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Sun, 28 Jul 2024 13:07:17 -0700 Subject: [PATCH 087/242] WIP. About to merge openfl PR 996 --- examples/fl_post/fl/build.sh | 13 ++++++++----- examples/fl_post/fl/clean.sh | 0 .../fl/mlcube/workspace/training_config.yaml | 8 +++++++- examples/fl_post/fl/sync.sh | 2 ++ examples/fl_post/fl/test.sh | 2 ++ 5 files changed, 19 insertions(+), 6 deletions(-) mode change 100644 => 100755 examples/fl_post/fl/build.sh mode change 100644 => 100755 examples/fl_post/fl/clean.sh mode change 100644 => 100755 examples/fl_post/fl/sync.sh mode change 100644 => 100755 examples/fl_post/fl/test.sh diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh old mode 100644 new mode 100755 index 9e0ea346f..49f1a2837 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -6,11 +6,14 @@ done BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then - git clone https://github.com/securefederatedai/openfl.git - cd openfl - git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + # git clone https://github.com/securefederatedai/openfl.git + # git clone https://github.com/hasan7n/openfl.git 
+ cd /home/msheller/git/openfl-hasan + # cd openfl + # git checkout e6f3f5fd4462307b2c9431184190167aa43d962f docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . - cd .. - rm -rf openfl + # cd .. + # rm -rf openfl + cd - fi mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/clean.sh b/examples/fl_post/fl/clean.sh old mode 100644 new mode 100755 diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 642f67575..96005daeb 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -35,7 +35,7 @@ network : assigner : defaults : plan/defaults/assigner.yaml - template : openfl.component.RandomGroupedAssigner + template : openfl.component.assigner.DynamicRandomGroupedAssigner settings : task_groups : - name : train_and_validate @@ -57,3 +57,9 @@ tasks : compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl_post/fl/sync.sh b/examples/fl_post/fl/sync.sh old mode 100644 new mode 100755 index 53460ec75..759666512 --- a/examples/fl_post/fl/sync.sh +++ b/examples/fl_post/fl/sync.sh @@ -5,8 +5,10 @@ cp mlcube/mlcube.yaml mlcube_col1/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col2/mlcube.yaml cp mlcube/mlcube.yaml mlcube_col3/mlcube.yaml +rm -r mlcube_col1/workspace/additional_files rm -r mlcube_col2/workspace/additional_files rm -r mlcube_col3/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_col1/workspace/additional_files cp -r mlcube/workspace/additional_files mlcube_col2/workspace/additional_files cp -r mlcube/workspace/additional_files 
mlcube_col3/workspace/additional_files diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh old mode 100644 new mode 100755 index b4b5d42eb..3343b9f0e --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -12,6 +12,8 @@ COL1="medperf --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF COL2="medperf --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" COL3="medperf --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +# medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & + # gnome-terminal -- bash -c "$AGG; bash" # gnome-terminal -- bash -c "$COL1; bash" # gnome-terminal -- bash -c "$COL2; bash" From 2fe06354ea713227f8a80e0689135a21ea19c864 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 12:02:59 +0200 Subject: [PATCH 088/242] add fl-admin mlcube --- examples/fl/fl/.gitignore | 1 + examples/fl/fl/setup_clean.sh | 1 + examples/fl/fl/setup_test_no_docker.sh | 17 ++ examples/fl/fl/test.sh | 1 + examples/fl/fl_admin/.gitignore | 3 + examples/fl/fl_admin/README.md | 6 + examples/fl/fl_admin/build.sh | 16 ++ examples/fl/fl_admin/clean.sh | 1 + examples/fl/fl_admin/mlcube/mlcube.yaml | 42 +++++ .../fl/fl_admin/mlcube/workspace/plan.yaml | 172 ++++++++++++++++++ examples/fl/fl_admin/project/Dockerfile | 11 ++ examples/fl/fl_admin/project/README.md | 35 ++++ examples/fl/fl_admin/project/admin.py | 60 ++++++ examples/fl/fl_admin/project/mlcube.py | 72 ++++++++ examples/fl/fl_admin/project/requirements.txt | 1 + examples/fl/fl_admin/project/utils.py | 98 ++++++++++ examples/fl/fl_admin/setup_clean.sh | 1 + examples/fl/fl_admin/setup_test_no_docker.sh | 8 + examples/fl/fl_admin/test.sh | 26 +++ examples/fl_post/fl/.gitignore | 3 +- examples/fl_post/fl/setup_clean.sh | 1 + examples/fl_post/fl/setup_test_no_docker.sh | 17 ++ 
examples/fl_post/fl/test.sh | 1 + 23 files changed, 593 insertions(+), 1 deletion(-) create mode 100644 examples/fl/fl_admin/.gitignore create mode 100644 examples/fl/fl_admin/README.md create mode 100644 examples/fl/fl_admin/build.sh create mode 100644 examples/fl/fl_admin/clean.sh create mode 100644 examples/fl/fl_admin/mlcube/mlcube.yaml create mode 100644 examples/fl/fl_admin/mlcube/workspace/plan.yaml create mode 100644 examples/fl/fl_admin/project/Dockerfile create mode 100644 examples/fl/fl_admin/project/README.md create mode 100644 examples/fl/fl_admin/project/admin.py create mode 100644 examples/fl/fl_admin/project/mlcube.py create mode 100644 examples/fl/fl_admin/project/requirements.txt create mode 100644 examples/fl/fl_admin/project/utils.py create mode 100644 examples/fl/fl_admin/setup_clean.sh create mode 100644 examples/fl/fl_admin/setup_test_no_docker.sh create mode 100644 examples/fl/fl_admin/test.sh diff --git a/examples/fl/fl/.gitignore b/examples/fl/fl/.gitignore index 6bd8bf2e2..167a3778d 100644 --- a/examples/fl/fl/.gitignore +++ b/examples/fl/fl/.gitignore @@ -1,3 +1,4 @@ mlcube_* ca quick* +for_admin diff --git a/examples/fl/fl/setup_clean.sh b/examples/fl/fl/setup_clean.sh index 9f9242024..6615c2968 100644 --- a/examples/fl/fl/setup_clean.sh +++ b/examples/fl/fl/setup_clean.sh @@ -3,3 +3,4 @@ rm -rf ./mlcube_col1 rm -rf ./mlcube_col2 rm -rf ./mlcube_col3 rm -rf ./ca +rm -rf ./for_admin diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index 879e84ced..71c5237cf 100644 --- a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -122,3 +122,20 @@ wget https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.t tar -xf init_weights_miccai.tar.gz rm init_weights_miccai.tar.gz cd ../../.. 
+ +# for admin +ADMIN_CN="admin@example.com" + +mkdir ./for_admin +mkdir ./for_admin/node_cert + +sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf +cd for_admin/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf +rm csr.csr +cp ../../ca/root.crt ../ca_cert/ +cd ../.. diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh index 3a154936a..95bd5b673 100644 --- a/examples/fl/fl/test.sh +++ b/examples/fl/fl/test.sh @@ -5,6 +5,7 @@ rm -r ./mlcube_agg/workspace/plan cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace +cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" diff --git a/examples/fl/fl_admin/.gitignore b/examples/fl/fl_admin/.gitignore new file mode 100644 index 000000000..6bd8bf2e2 --- /dev/null +++ b/examples/fl/fl_admin/.gitignore @@ -0,0 +1,3 @@ +mlcube_* +ca +quick* diff --git a/examples/fl/fl_admin/README.md b/examples/fl/fl_admin/README.md new file mode 100644 index 000000000..918f483e3 --- /dev/null +++ b/examples/fl/fl_admin/README.md @@ -0,0 +1,6 @@ +# How to run tests + +- Run `setup_test.sh` just once to create certs and download required data. +- Run `test.sh` to start the aggregator and three collaborators. +- Run `clean.sh` to be able to rerun `test.sh` freshly. +- Run `setup_clean.sh` to clear what has been generated in step 1. 
diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh new file mode 100644 index 000000000..9e0ea346f --- /dev/null +++ b/examples/fl/fl_admin/build.sh @@ -0,0 +1,16 @@ +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/securefederatedai/openfl.git + cd openfl + git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl +fi +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl/fl_admin/clean.sh b/examples/fl/fl_admin/clean.sh new file mode 100644 index 000000000..e5a08daf4 --- /dev/null +++ b/examples/fl/fl_admin/clean.sh @@ -0,0 +1 @@ +rm -rf mlcube_admin/workspace/status diff --git a/examples/fl/fl_admin/mlcube/mlcube.yaml b/examples/fl/fl_admin/mlcube/mlcube.yaml new file mode 100644 index 000000000..838726a68 --- /dev/null +++ b/examples/fl/fl_admin/mlcube/mlcube.yaml @@ -0,0 +1,42 @@ +name: FL MLCube +description: FL MLCube +authors: + - { name: MLCommons Medical Working Group } + +platform: + accelerator_count: 0 + +docker: + # Image name + image: mlcommons/medperf-fl-admin:1.0.0 + # Docker build context relative to $MLCUBE_ROOT. Default is `build`. + build_context: "../project" + # Docker file name within docker build context, default is `Dockerfile`. 
+ build_file: "Dockerfile" + +tasks: + get_experiment_status: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + output_status_file: { type: "file", default: "status/status.yaml" } + temp_dir: tmp/ + add_collaborator: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ + remove_collaborator: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ \ No newline at end of file diff --git a/examples/fl/fl_admin/mlcube/workspace/plan.yaml b/examples/fl/fl_admin/mlcube/workspace/plan.yaml new file mode 100644 index 000000000..0689addf2 --- /dev/null +++ b/examples/fl/fl_admin/mlcube/workspace/plan.yaml @@ -0,0 +1,172 @@ +aggregator: + defaults: plan/defaults/aggregator.yaml + settings: + best_state_path: save/classification_best.pbuf + db_store_rounds: 2 + init_state_path: save/classification_init.pbuf + last_state_path: save/classification_last.pbuf + rounds_to_train: 50 + write_logs: true + admins: + - col3@example.com + allowed_admin_endpoints: + - GetExperimentStatus + template: openfl.component.Aggregator +assigner: + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + - aggregated_model_validation + - train + - locally_tuned_model_validation + template: openfl.component.RandomGroupedAssigner +collaborator: + settings: + db_store_rounds: 1 + delta_updates: false + opt_treatment: RESET + template: openfl.component.Collaborator +compression_pipeline: + settings: {} + template: openfl.pipelines.NoCompressionPipeline +data_loader: + settings: + feature_shape: + - 128 + - 128 + template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper +network: + settings: + agg_addr: hasan-hp-zbook-15-g3.home + agg_port: 50273 + cert_folder: cert + client_reconnect_interval: 5 + disable_client_auth: false + hash_salt: auto 
+ tls: true + template: openfl.federation.Network +task_runner: + settings: + device: cpu + gandlf_config: + batch_size: 16 + clip_grad: null + clip_mode: null + data_augmentation: {} + data_postprocessing: {} + data_preprocessing: + resize: + - 128 + - 128 + enable_padding: false + grid_aggregator_overlap: crop + in_memory: false + inference_mechanism: + grid_aggregator_overlap: crop + patch_overlap: 0 + learning_rate: 0.001 + loss_function: cel + medcam_enabled: false + memory_save_mode: false + metrics: + accuracy: + average: weighted + mdmc_average: samplewise + multi_class: true + subset_accuracy: false + threshold: 0.5 + balanced_accuracy: None + classification_accuracy: None + f1: + average: weighted + f1: + average: weighted + mdmc_average: samplewise + multi_class: true + threshold: 0.5 + modality: rad + model: + amp: false + architecture: resnet18 + base_filters: 32 + batch_norm: true + class_list: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + dimension: 2 + final_layer: sigmoid + ignore_label_validation: None + n_channels: 3 + norm_type: batch + num_channels: 3 + save_at_every_epoch: false + type: torch + nested_training: + testing: 1 + validation: -5 + num_epochs: 2 + opt: adam + optimizer: + type: adam + output_dir: . 
+ parallel_compute_command: "" + patch_sampler: uniform + patch_size: + - 128 + - 128 + - 1 + patience: 1 + pin_memory_dataloader: false + print_rgb_label_warning: true + q_max_length: 5 + q_num_workers: 0 + q_samples_per_volume: 1 + q_verbose: false + save_masks: false + save_output: false + save_training: false + scaling_factor: 1 + scheduler: + step_size: 0.0002 + type: triangle + track_memory_usage: false + verbose: false + version: + maximum: 0.0.20-dev + minimum: 0.0.20-dev + weighted_loss: true + train_csv: train_path_full.csv + val_csv: val_path_full.csv + template: openfl.federated.task.runner_gandlf.GaNDLFTaskRunner +tasks: + aggregated_model_validation: + function: validate + kwargs: + apply: global + metrics: + - valid_loss + - valid_accuracy + locally_tuned_model_validation: + function: validate + kwargs: + apply: local + metrics: + - valid_loss + - valid_accuracy + settings: {} + train: + function: train + kwargs: + epochs: 1 + metrics: + - loss + - train_accuracy diff --git a/examples/fl/fl_admin/project/Dockerfile b/examples/fl/fl_admin/project/Dockerfile new file mode 100644 index 000000000..fc7afaf2e --- /dev/null +++ b/examples/fl/fl_admin/project/Dockerfile @@ -0,0 +1,11 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube workspace +COPY . 
/mlcube_project + +ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file diff --git a/examples/fl/fl_admin/project/README.md b/examples/fl/fl_admin/project/README.md new file mode 100644 index 000000000..f9ee6768d --- /dev/null +++ b/examples/fl/fl_admin/project/README.md @@ -0,0 +1,35 @@ +# How to configure container build for your application + +- (Explanation TBD) + +# How to configure container for custom FL software + +- (Explanation TBD) + +# How to build + +- Build the openfl base image: + +```bash +git clone https://github.com/securefederatedai/openfl.git +cd openfl +git checkout e6f3f5fd4462307b2c9431184190167aa43d962f +docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . +cd .. +rm -rf openfl +``` + +- Build the MLCube + +```bash +cd .. +bash build.sh +``` + +# NOTE + +for local experiments, internal IP address or localhost will not work. Use internal fqdn. + +# How to customize + +TBD diff --git a/examples/fl/fl_admin/project/admin.py b/examples/fl/fl_admin/project/admin.py new file mode 100644 index 000000000..04318810c --- /dev/null +++ b/examples/fl/fl_admin/project/admin.py @@ -0,0 +1,60 @@ +from subprocess import check_call +from utils import ( + get_col_label_to_add, + get_col_cn_to_add, + get_col_label_to_remove, + get_col_cn_to_remove, +) + + +def get_experiment_status(workspace_folder, admin_cn, output_status_file): + check_call( + [ + "fx", + "admin", + "get_experiment_status", + "-n", + admin_cn, + "--output_file", + output_status_file, + ], + cwd=workspace_folder, + ) + + +def add_collaborator(workspace_folder, admin_cn): + col_label = get_col_label_to_add() + col_cn = get_col_cn_to_add() + check_call( + [ + "fx", + "admin", + "add_collaborator", + "-n", + admin_cn, + "--col_label", + col_label, + "--col_cn", + col_cn, + ], + cwd=workspace_folder, + ) + + +def remove_collaborator(workspace_folder, admin_cn): + col_label = get_col_label_to_remove() + col_cn = get_col_cn_to_remove() + check_call( + [ + 
"fx", + "admin", + "remove_collaborator", + "-n", + admin_cn, + "--col_label", + col_label, + "--col_cn", + col_cn, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl/fl_admin/project/mlcube.py b/examples/fl/fl_admin/project/mlcube.py new file mode 100644 index 000000000..9a5019a22 --- /dev/null +++ b/examples/fl/fl_admin/project/mlcube.py @@ -0,0 +1,72 @@ +"""MLCube handler file""" + +import os +import shutil +import typer +from utils import setup_ws +from admin import get_experiment_status, add_collaborator, remove_collaborator + +app = typer.Typer() + + +def _setup(temp_dir): + tmp_folder = os.path.join(temp_dir, ".tmp") + os.makedirs(tmp_folder, exist_ok=True) + # TODO: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmp_folder + os.environ["GRPC_VERBOSITY"] = "ERROR" + + +def _teardown(temp_dir): + tmp_folder = os.path.join(temp_dir, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + +@app.command("get_experiment_status") +def get_experiment_status_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + output_status_file: str = typer.Option(..., "--output_status_file"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + get_experiment_status(workspace_folder, admin_cn, output_status_file) + _teardown(temp_dir) + + +@app.command("add_collaborator") +def add_collaborator_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + add_collaborator(workspace_folder, 
admin_cn) + _teardown(temp_dir) + + +@app.command("remove_collaborator") +def remove_collaborator_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + remove_collaborator(workspace_folder, admin_cn) + _teardown(temp_dir) + + +if __name__ == "__main__": + app() diff --git a/examples/fl/fl_admin/project/requirements.txt b/examples/fl/fl_admin/project/requirements.txt new file mode 100644 index 000000000..92c979407 --- /dev/null +++ b/examples/fl/fl_admin/project/requirements.txt @@ -0,0 +1 @@ +typer==0.9.0 \ No newline at end of file diff --git a/examples/fl/fl_admin/project/utils.py b/examples/fl/fl_admin/project/utils.py new file mode 100644 index 000000000..640ff72cc --- /dev/null +++ b/examples/fl/fl_admin/project/utils.py @@ -0,0 +1,98 @@ +import os +import shutil + + +def setup_ws(node_cert_folder, ca_cert_folder, plan_path, temp_dir): + workspace_folder = os.path.join(temp_dir, "workspace") + create_workspace(workspace_folder) + prepare_plan(plan_path, workspace_folder) + cn = get_admin_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + return workspace_folder, cn + + +def create_workspace(fl_workspace): + plan_folder = os.path.join(fl_workspace, "plan") + workspace_config = os.path.join(fl_workspace, ".workspace") + defaults_file = os.path.join(plan_folder, "defaults") + + os.makedirs(plan_folder, exist_ok=True) + with open(defaults_file, "w") as f: + f.write("../../workspace/plan/defaults\n\n") + with open(workspace_config, "w") as f: + f.write("current_plan_name: default\n\n") + + +def get_admin_cn(): + return os.environ["MEDPERF_ADMIN_PARTICIPANT_CN"] + + +def 
get_col_label_to_add(): + return os.environ["MEDPERF_COLLABORATOR_LABEL_TO_ADD"] + + +def get_col_cn_to_add(): + return os.environ["MEDPERF_COLLABORATOR_CN_TO_ADD"] + + +def get_col_label_to_remove(): + return os.environ["MEDPERF_COLLABORATOR_LABEL_TO_REMOVE"] + + +def get_col_cn_to_remove(): + return os.environ["MEDPERF_COLLABORATOR_CN_TO_REMOVE"] + + +def prepare_plan(plan_path, fl_workspace): + target_plan_folder = os.path.join(fl_workspace, "plan") + os.makedirs(target_plan_folder, exist_ok=True) + + target_plan_file = os.path.join(target_plan_folder, "plan.yaml") + shutil.copyfile(plan_path, target_plan_file) + + +def prepare_node_cert( + node_cert_folder, target_cert_folder_name, target_cert_name, fl_workspace +): + error_msg = f"{node_cert_folder} should contain only two files: *.crt and *.key" + + files = os.listdir(node_cert_folder) + file_extensions = [file.split(".")[-1] for file in files] + if len(files) != 2 or sorted(file_extensions) != ["crt", "key"]: + raise RuntimeError(error_msg) + + if files[0].endswith(".crt") and files[1].endswith(".key"): + cert_file = files[0] + key_file = files[1] + else: + key_file = files[0] + cert_file = files[1] + + key_file = os.path.join(node_cert_folder, key_file) + cert_file = os.path.join(node_cert_folder, cert_file) + + target_cert_folder = os.path.join(fl_workspace, "cert", target_cert_folder_name) + os.makedirs(target_cert_folder, exist_ok=True) + target_cert_file = os.path.join(target_cert_folder, f"{target_cert_name}.crt") + target_key_file = os.path.join(target_cert_folder, f"{target_cert_name}.key") + + os.symlink(key_file, target_key_file) + os.symlink(cert_file, target_cert_file) + + +def prepare_ca_cert(ca_cert_folder, fl_workspace): + error_msg = f"{ca_cert_folder} should contain only one file: *.crt" + + files = os.listdir(ca_cert_folder) + file = files[0] + if len(files) != 1 or not file.endswith(".crt"): + raise RuntimeError(error_msg) + + file = os.path.join(ca_cert_folder, file) + + 
target_ca_cert_folder = os.path.join(fl_workspace, "cert") + os.makedirs(target_ca_cert_folder, exist_ok=True) + target_ca_cert_file = os.path.join(target_ca_cert_folder, "cert_chain.crt") + + os.symlink(file, target_ca_cert_file) diff --git a/examples/fl/fl_admin/setup_clean.sh b/examples/fl/fl_admin/setup_clean.sh new file mode 100644 index 000000000..b82c06ab4 --- /dev/null +++ b/examples/fl/fl_admin/setup_clean.sh @@ -0,0 +1 @@ +rm -rf ./mlcube_admin diff --git a/examples/fl/fl_admin/setup_test_no_docker.sh b/examples/fl/fl_admin/setup_test_no_docker.sh new file mode 100644 index 000000000..cb2beb5b8 --- /dev/null +++ b/examples/fl/fl_admin/setup_test_no_docker.sh @@ -0,0 +1,8 @@ +cp -r ./mlcube ./mlcube_admin + +# Get your node cert folder and ca cert folder from the aggregator setup. Modify paths as needed. +cp -r ../fl/for_admin/node_cert ./mlcube_admin/node_cert +cp -r ../fl/for_admin/ca_cert ./mlcube_admin/ca_cert + +# Note that you should use the same plan used in the federation +# cp ../fl/for_admin/plan.yaml ./mlcube_admin/workspace/plan.yaml diff --git a/examples/fl/fl_admin/test.sh b/examples/fl/fl_admin/test.sh new file mode 100644 index 000000000..3e781d937 --- /dev/null +++ b/examples/fl/fl_admin/test.sh @@ -0,0 +1,26 @@ +# Make sure an aggregator is up somewhere, and it is configured to +# accept admin@example.com as an admin and to allow any endpoints you are willing to test + +# Uncomment and test + +# # GET EXPERIMENT STATUS +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_args=$env_arg1 +# medperf mlcube run --mlcube ./mlcube_admin --task get_experiment_status \ +# -e $env_args + +## ADD COLLABORATOR +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_COLLABORATOR_LABEL_TO_ADD=col3@example.com" +# env_arg3="MEDPERF_COLLABORATOR_CN_TO_ADD=col3@example.com" +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task add_collaborator \ +# -e $env_args + +## 
REMOVE COLLABORATOR +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_COLLABORATOR_LABEL_TO_REMOVE=col3@example.com" +# env_arg3="MEDPERF_COLLABORATOR_CN_TO_REMOVE=col3@example.com" +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task remove_collaborator \ +# -e $env_args diff --git a/examples/fl_post/fl/.gitignore b/examples/fl_post/fl/.gitignore index 70b318917..13ab94d3a 100644 --- a/examples/fl_post/fl/.gitignore +++ b/examples/fl_post/fl/.gitignore @@ -2,4 +2,5 @@ mlcube_* ca quick* mlcube/workspace/additional_files/init_nnunet/* -mlcube/workspace/additional_files/init_weights/* \ No newline at end of file +mlcube/workspace/additional_files/init_weights/* +for_admin diff --git a/examples/fl_post/fl/setup_clean.sh b/examples/fl_post/fl/setup_clean.sh index 9f9242024..6615c2968 100644 --- a/examples/fl_post/fl/setup_clean.sh +++ b/examples/fl_post/fl/setup_clean.sh @@ -3,3 +3,4 @@ rm -rf ./mlcube_col1 rm -rf ./mlcube_col2 rm -rf ./mlcube_col3 rm -rf ./ca +rm -rf ./for_admin diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index e422a0d8e..a01ab87ab 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -96,3 +96,20 @@ echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols file echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml + +# for admin +ADMIN_CN="admin@example.com" + +mkdir ./for_admin +mkdir ./for_admin/node_cert + +sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf +cd for_admin/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key 
\ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf +rm csr.csr +cp ../../ca/root.crt ../ca_cert/ +cd ../.. diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index b4b5d42eb..e7621af47 100644 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -5,6 +5,7 @@ rm -r ./mlcube_agg/workspace/plan cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace +cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" From 7cc3dd2616807184ad7bea6ad85486510070d54b Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 12:51:23 +0200 Subject: [PATCH 089/242] allow passing col list when starting an event --- cli/medperf/commands/dataset/train.py | 11 +++--- cli/medperf/commands/training/start_event.py | 36 ++++++++++++++++--- cli/medperf/commands/training/training.py | 5 ++- .../traindataset_association/serializers.py | 23 +++++++----- 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 231db58d9..867b9cc27 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -54,11 +54,10 @@ def __init__( def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Execution: {self.training_exp.name}") - # self.event = TrainingEvent.from_experiment(self.training_exp_id) + self.event = TrainingEvent.from_experiment(self.training_exp_id) self.dataset = Dataset.get(self.data_uid) self.user_email: str = get_medperf_user_data()["email"] - # self.out_logs = os.path.join(self.event.col_out_logs, str(self.dataset.id)) - self.out_logs = os.path.join(self.training_exp.path, str(self.dataset.id)) + self.out_logs = 
os.path.join(self.event.col_out_logs, str(self.dataset.id)) def validate(self): if self.dataset.id is None: @@ -69,9 +68,9 @@ def validate(self): msg = "The provided dataset is not operational." raise InvalidArgumentError(msg) - # if self.event.finished: - # msg = "The provided training experiment has to start a training event." - # raise InvalidArgumentError(msg) + if self.event.finished: + msg = "The provided training experiment has to start a training event." + raise InvalidArgumentError(msg) def check_existing_outputs(self): msg = ( diff --git a/cli/medperf/commands/training/start_event.py b/cli/medperf/commands/training/start_event.py index e24c2fa48..18f6400aa 100644 --- a/cli/medperf/commands/training/start_event.py +++ b/cli/medperf/commands/training/start_event.py @@ -2,21 +2,32 @@ from medperf.entities.event import TrainingEvent from medperf.utils import approval_prompt, dict_pretty_print, get_participant_label from medperf.exceptions import CleanExit, InvalidArgumentError +import yaml +import os class StartEvent: @classmethod - def run(cls, training_exp_id: int, name: str, approval: bool = False): - submission = cls(training_exp_id, name, approval) + def run( + cls, + training_exp_id: int, + name: str, + participants_list_file: str = None, + approval: bool = False, + ): + submission = cls(training_exp_id, name, participants_list_file, approval) submission.prepare() submission.validate() - submission.create_participants_list() + submission.prepare_participants_list() updated_body = submission.submit() submission.write(updated_body) - def __init__(self, training_exp_id: int, name: str, approval): + def __init__( + self, training_exp_id: int, name: str, participants_list_file: str, approval + ): self.training_exp_id = training_exp_id self.name = name + self.participants_list_file = participants_list_file self.approved = approval def prepare(self): @@ -25,8 +36,23 @@ def prepare(self): def validate(self): if self.training_exp.approval_status != "APPROVED": 
raise InvalidArgumentError("This experiment has not been approved yet") + if self.participants_list_file is not None: + if not os.path.exists(self.participants_list_file): + raise InvalidArgumentError( + "Provided participants list path does not exist" + ) - def create_participants_list(self): + def prepare_participants_list(self): + if self.participants_list_file is None: + self._prepare_participants_list_from_associations() + else: + self._prepare_participants_list_from_file() + + def _prepare_participants_list_from_file(self): + with open(self.participants_list_file) as f: + self.participants_list = yaml.safe_load(f) + + def _prepare_participants_list_from_associations(self): datasets_with_users = TrainingExp.get_datasets_with_users(self.training_exp_id) participants_list = {} for dataset in datasets_with_users: diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index 36328eb96..b877afead 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -72,10 +72,13 @@ def start_event( ..., "--training_exp_id", "-t", help="UID of the desired benchmark" ), name: str = typer.Option(..., "--name", "-n", help="Name of the benchmark"), + participants_list_file: str = typer.Option( + None, "--participants_list_file", "-p", help="Path to a YAML file containing the participants list" + ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" - StartEvent.run(training_exp_id, name, approval) + StartEvent.run(training_exp_id, name, participants_list_file, approval) config.ui.print("✅ Done!") diff --git a/server/traindataset_association/serializers.py b/server/traindataset_association/serializers.py index 3f2d8d88c..950b073bc 100644 --- a/server/traindataset_association/serializers.py +++ b/server/traindataset_association/serializers.py @@ -10,6 +10,18 @@ ) + + +def is_approved_participant(training_exp, 
dataset): + # training_exp event status + event = training_exp.event + if not event or event.finished: + return + + # TODO: modify when we use dataset labels + # TODO: is there a cleaner way? We are making assumptions on the json field structure + participants_list = event.participants.values() + return dataset.owner.email in participants_list + + class ExperimentDatasetListSerializer(serializers.ModelSerializer): class Meta: model = ExperimentDataset @@ -30,13 +42,6 @@ def validate(self, data): "Association requests can be made only on an approved training experiment" ) - # training_exp event status - event = training_exp.event - if event and not event.finished: - raise serializers.ValidationError( - "The training experiment does not currently accept associations" - ) - # dataset state dataset_obj = Dataset.objects.get(pk=dataset) dataset_state = dataset_obj.state @@ -71,7 +76,9 @@ def create(self, validated_data): validated_data["dataset"].owner.id == validated_data["training_exp"].owner.id ) - if same_owner: + if same_owner or is_approved_participant( + validated_data["training_exp"], validated_data["dataset"] + ): validated_data["approval_status"] = "APPROVED" validated_data["approved_at"] = timezone.now() return ExperimentDataset.objects.create(**validated_data) From 3f62a0b0c7fd056c3f42f65a109ed8baa6c9ca43 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 15:18:02 +0200 Subject: [PATCH 090/242] add fl admin mlcube to training exp object --- cli/medperf/commands/training/submit.py | 10 +++++-- cli/medperf/commands/training/training.py | 4 +++ cli/medperf/entities/training_exp.py | 1 + ...0002_trainingexperiment_fl_admin_mlcube.py | 26 +++++++++++++++++++ server/training/models.py | 7 +++++ 5 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py index 7da6a96ba..9877497bc 
100644 --- a/cli/medperf/commands/training/submit.py +++ b/cli/medperf/commands/training/submit.py @@ -25,7 +25,9 @@ def run(cls, training_exp_info: dict): with ui.interactive(): ui.text = "Getting FL MLCube" - submission.get_mlcube() + submission.get_fl_mlcube() + ui.text = "Getting FL admin MLCube" + submission.get_fl_admin_mlcube() ui.print("> Completed retrieving FL MLCube") ui.text = "Submitting TrainingExp to MedPerf" updated_benchmark_body = submission.submit() @@ -37,10 +39,14 @@ def __init__(self, training_exp_info: dict): self.training_exp = TrainingExp(**training_exp_info) config.tmp_paths.append(self.training_exp.path) - def get_mlcube(self): + def get_fl_mlcube(self): mlcube_id = self.training_exp.fl_mlcube Cube.get(mlcube_id) + def get_fl_admin_mlcube(self): + mlcube_id = self.training_exp.fl_admin_mlcube + Cube.get(mlcube_id) + def submit(self): updated_body = self.training_exp.upload() return updated_body diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index b877afead..717c6ed3c 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -27,6 +27,9 @@ def submit( fl_mlcube: int = typer.Option( ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" ), + fl_admin_mlcube: int = typer.Option( + ..., "--fl-admin-mlcube", "-a", help="FL admin interface MLCube" + ), operational: bool = typer.Option( False, "--operational", @@ -39,6 +42,7 @@ def submit( "description": description, "docs_url": docs_url, "fl_mlcube": fl_mlcube, + "fl_admin_mlcube": fl_admin_mlcube, "demo_dataset_tarball_url": "link", "demo_dataset_tarball_hash": "hash", "demo_dataset_generated_uid": "uid", diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index 874a2e655..b3cbff37e 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -28,6 +28,7 @@ class TrainingExp(Entity, MedperfSchema, ApprovableSchema, 
DeployableSchema): demo_dataset_generated_uid: str data_preparation_mlcube: int fl_mlcube: int + fl_admin_mlcube: Optional[int] plan: dict = {} metadata: dict = {} user_metadata: dict = {} diff --git a/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py b/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py new file mode 100644 index 000000000..aebbbe3ce --- /dev/null +++ b/server/training/migrations/0002_trainingexperiment_fl_admin_mlcube.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.11 on 2024-07-28 22:15 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("mlcube", "0002_alter_mlcube_unique_together"), + ("training", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="trainingexperiment", + name="fl_admin_mlcube", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="fl_admin_mlcube", + to="mlcube.mlcube", + ), + ), + ] diff --git a/server/training/models.py b/server/training/models.py index a65653119..db1e92c72 100644 --- a/server/training/models.py +++ b/server/training/models.py @@ -32,6 +32,13 @@ class TrainingExperiment(models.Model): on_delete=models.PROTECT, related_name="fl_mlcube", ) + fl_admin_mlcube = models.ForeignKey( + "mlcube.MlCube", + on_delete=models.PROTECT, + related_name="fl_admin_mlcube", + blank=True, + null=True, + ) metadata = models.JSONField(default=dict, blank=True, null=True) state = models.CharField(choices=STATES, max_length=100, default="DEVELOPMENT") From 0c6a89fca5be94ed805fcc6ae86c406cb68e90a0 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 18:11:18 +0200 Subject: [PATCH 091/242] some setup/tests changes --- examples/fl/fl/build.sh | 4 ++-- .../fl/mlcube/workspace/training_config.yaml | 15 ++++++++++++++- examples/fl/fl/setup_test_no_docker.sh | 2 +- examples/fl/fl_admin/build.sh | 4 ++-- 
.../fl/fl_admin/mlcube/workspace/plan.yaml | 18 ++++++++++++------ examples/fl_post/fl/build.sh | 4 ++-- .../fl/mlcube/workspace/training_config.yaml | 14 +++++++++++++- examples/fl_post/fl/setup_test_no_docker.sh | 2 +- 8 files changed, 47 insertions(+), 16 deletions(-) diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh index 9e0ea346f..96b4c9216 100644 --- a/examples/fl/fl/build.sh +++ b/examples/fl/fl/build.sh @@ -6,9 +6,9 @@ done BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then - git clone https://github.com/securefederatedai/openfl.git + git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl/fl/mlcube/workspace/training_config.yaml b/examples/fl/fl/mlcube/workspace/training_config.yaml index e5ba18e21..e08e37752 100644 --- a/examples/fl/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl/fl/mlcube/workspace/training_config.yaml @@ -6,9 +6,16 @@ aggregator: last_state_path: save/classification_last.pbuf rounds_to_train: 2 write_logs: true + admins: + - admin@example.com + allowed_admin_endpoints: + - GetExperimentStatus + - AddCollaborator + - RemoveCollaborator template: openfl.component.Aggregator assigner: settings: + template : openfl.component.assigner.DynamicRandomGroupedAssigner task_groups: - name: train_and_validate percentage: 1.0 @@ -162,4 +169,10 @@ tasks: epochs: 1 metrics: - loss - - train_accuracy \ No newline at end of file + - train_accuracy + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index 71c5237cf..4c1267cb3 100644 --- 
a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -137,5 +137,5 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf rm csr.csr -cp ../../ca/root.crt ../ca_cert/ +cp -r ../../ca/root.crt ../ca_cert/ cd ../.. diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh index 9e0ea346f..96b4c9216 100644 --- a/examples/fl/fl_admin/build.sh +++ b/examples/fl/fl_admin/build.sh @@ -6,9 +6,9 @@ done BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then - git clone https://github.com/securefederatedai/openfl.git + git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. 
rm -rf openfl diff --git a/examples/fl/fl_admin/mlcube/workspace/plan.yaml b/examples/fl/fl_admin/mlcube/workspace/plan.yaml index 0689addf2..e08e37752 100644 --- a/examples/fl/fl_admin/mlcube/workspace/plan.yaml +++ b/examples/fl/fl_admin/mlcube/workspace/plan.yaml @@ -1,19 +1,21 @@ aggregator: - defaults: plan/defaults/aggregator.yaml settings: best_state_path: save/classification_best.pbuf db_store_rounds: 2 init_state_path: save/classification_init.pbuf last_state_path: save/classification_last.pbuf - rounds_to_train: 50 + rounds_to_train: 2 write_logs: true admins: - - col3@example.com + - admin@example.com allowed_admin_endpoints: - GetExperimentStatus + - AddCollaborator + - RemoveCollaborator template: openfl.component.Aggregator assigner: settings: + template : openfl.component.assigner.DynamicRandomGroupedAssigner task_groups: - name: train_and_validate percentage: 1.0 @@ -39,8 +41,6 @@ data_loader: template: openfl.federated.data.loader_gandlf.GaNDLFDataLoaderWrapper network: settings: - agg_addr: hasan-hp-zbook-15-g3.home - agg_port: 50273 cert_folder: cert client_reconnect_interval: 5 disable_client_auth: false @@ -51,6 +51,7 @@ task_runner: settings: device: cpu gandlf_config: + memory_save_mode: false # batch_size: 16 clip_grad: null clip_mode: null @@ -69,7 +70,6 @@ task_runner: learning_rate: 0.001 loss_function: cel medcam_enabled: false - memory_save_mode: false metrics: accuracy: average: weighted @@ -170,3 +170,9 @@ tasks: metrics: - loss - train_accuracy + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 9e0ea346f..96b4c9216 100644 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -6,9 +6,9 @@ done BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then - git clone 
https://github.com/securefederatedai/openfl.git + git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout e6f3f5fd4462307b2c9431184190167aa43d962f + git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 642f67575..1774986e9 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -6,6 +6,12 @@ aggregator : best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf rounds_to_train : 10 + admins: + - admin@example.com + allowed_admin_endpoints: + - GetExperimentStatus + - AddCollaborator + - RemoveCollaborator collaborator : defaults : plan/defaults/collaborator.yaml @@ -35,7 +41,7 @@ network : assigner : defaults : plan/defaults/assigner.yaml - template : openfl.component.RandomGroupedAssigner + template : openfl.component.assigner.DynamicRandomGroupedAssigner settings : task_groups : - name : train_and_validate @@ -57,3 +63,9 @@ tasks : compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml + +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 600 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index a01ab87ab..e9082e11a 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -111,5 +111,5 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile 
../../csr.conf rm csr.csr -cp ../../ca/root.crt ../ca_cert/ +cp -r ../../ca/root.crt ../ca_cert/ cd ../.. From 65270f440fbe5925c023ccf63509f0fc6e4f8c96 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 18:12:58 +0200 Subject: [PATCH 092/242] bugfix in tests --- examples/fl/fl/setup_test_no_docker.sh | 2 +- examples/fl_post/fl/setup_test_no_docker.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index 4c1267cb3..0f46f1225 100644 --- a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -137,5 +137,5 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf rm csr.csr -cp -r ../../ca/root.crt ../ca_cert/ +cp -r ../../ca/root.crt ../ca_cert/root.crt cd ../.. diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index e9082e11a..da89fc224 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -111,5 +111,5 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf rm csr.csr -cp -r ../../ca/root.crt ../ca_cert/ +cp -r ../../ca/root.crt ../ca_cert/root.crt cd ../.. 
From 251cdd1cc20d486a19f83cab3a9fa2f4d2780cfd Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 18:14:18 +0200 Subject: [PATCH 093/242] tests bugfixes --- examples/fl/fl/setup_test_no_docker.sh | 1 + examples/fl_post/fl/setup_test_no_docker.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index 0f46f1225..e7625e60d 100644 --- a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -137,5 +137,6 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf rm csr.csr +mkdir ../ca_cert cp -r ../../ca/root.crt ../ca_cert/root.crt cd ../.. diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index da89fc224..58276a159 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -111,5 +111,6 @@ openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3 openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf rm csr.csr +mkdir ../ca_cert cp -r ../../ca/root.crt ../ca_cert/root.crt cd ../.. 
From 7525b3523d4abd22cb1db92925311ae0c7198068 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Mon, 29 Jul 2024 10:37:05 -0700 Subject: [PATCH 094/242] Added admin config to test plan --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 96005daeb..d20a46aca 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -6,6 +6,12 @@ aggregator : best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf rounds_to_train : 10 + admins: + - col1@example.com + allowed_admin_endpoints: + - GetExperimentStatus + - AddCollaborator + - RemoveCollaborator collaborator : defaults : plan/defaults/collaborator.yaml From 89664e2aa6cb5e605cf249e6aff1d6510c4d9248 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Mon, 29 Jul 2024 13:06:53 -0700 Subject: [PATCH 095/242] Added admin endpoints to list of functions that can use the network. Added fix for files in nnunet data directories. 
--- cli/medperf/entities/cube.py | 3 +++ examples/fl_post/fl/project/nnunet_data_setup.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 9ff5ce0e1..6ea5d7dde 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -258,6 +258,9 @@ def run( "trust", "get_client_cert", "get_server_cert", + "get_experiment_status", + "add_collaborator", + "remove_collaborator", ]: cmd += " --network=none" if config.gpus is not None: diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 30ff9addd..3f24c9515 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -17,6 +17,11 @@ '_0002': '_brain_t1c.nii.gz', '_0003': '_brain_t2f.nii.gz'} +def get_subdirs(parent_directory): + subjects = os.listdir(parent_directory) + subjects = [p for p in subjects if os.path.isdir(os.path.join(parent_directory, p)) and not p.startswith(".")] + return subjects + def subject_time_to_mask_path(pardir, subject, timestamp): mask_fname = f'{subject}_{timestamp}_tumorMask_model_0.nii.gz' @@ -68,7 +73,7 @@ def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_label if verbose: print(f"\n#######\nsymlinking subject: {postopp_subject_dir}\n########\nPostopp_data_dirpath: {postopp_data_dirpath}\n\n\n\n") postopp_subject_dirpath = os.path.join(postopp_data_dirpath, postopp_subject_dir) - all_timestamps = sorted(list(os.listdir(postopp_subject_dirpath))) + all_timestamps = sorted(list(get_subdirs(postopp_subject_dirpath))) if timestamp_selection == 'latest': timestamps = all_timestamps[-1:] elif timestamp_selection == 'earliest': @@ -103,7 +108,7 @@ def symlink_one_subject(postopp_subject_dir, postopp_data_dirpath, postopp_label def doublecheck_postopp_pardir(postopp_pardir, verbose=False): if verbose: print(f"Checking postopp_pardir: 
{postopp_pardir}") - postopp_subdirs = list(os.listdir(postopp_pardir)) + postopp_subdirs = list(get_subdirs(postopp_pardir)) if 'data' not in postopp_subdirs: raise ValueError(f"'data' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") if 'labels' not in postopp_subdirs: @@ -326,8 +331,7 @@ def setup_fl_data(postopp_pardir, postopp_data_dirpath = os.path.join(postopp_pardir, 'data') postopp_labels_dirpath = os.path.join(postopp_pardir, 'labels') - all_subjects = list(os.listdir(postopp_data_dirpath)) - + all_subjects = list(get_subdirs(postopp_data_dirpath)) # Track the subjects and timestamps for each shard subject_to_timestamps = {} From e23bfaec8413357551d61ebc9392ae3dcae1885a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 29 Jul 2024 23:44:48 +0200 Subject: [PATCH 096/242] add cutofftime task in cube.py --- cli/medperf/entities/cube.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 6ea5d7dde..4a6050cd9 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -261,6 +261,7 @@ def run( "get_experiment_status", "add_collaborator", "remove_collaborator", + "set_straggler_cuttoff_time", ]: cmd += " --network=none" if config.gpus is not None: From 2c4eb432f0648e88742585f391bb2ada0f061ff2 Mon Sep 17 00:00:00 2001 From: Micah Sheller Date: Mon, 29 Jul 2024 16:38:09 -0700 Subject: [PATCH 097/242] Successfully tested with admin endpoint to set straggler handler timeout --- cli/medperf/entities/cube.py | 1 + .../fl/fl/mlcube/workspace/training_config.yaml | 1 + examples/fl/fl_admin/mlcube/mlcube.yaml | 8 ++++++++ examples/fl/fl_admin/project/admin.py | 17 +++++++++++++++++ examples/fl/fl_admin/project/mlcube.py | 17 ++++++++++++++++- examples/fl/fl_admin/project/utils.py | 4 ++++ 6 files changed, 47 insertions(+), 1 deletion(-) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 6ea5d7dde..4a6050cd9 100644 --- 
a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -261,6 +261,7 @@ def run( "get_experiment_status", "add_collaborator", "remove_collaborator", + "set_straggler_cuttoff_time", ]: cmd += " --network=none" if config.gpus is not None: diff --git a/examples/fl/fl/mlcube/workspace/training_config.yaml b/examples/fl/fl/mlcube/workspace/training_config.yaml index e08e37752..9400964d0 100644 --- a/examples/fl/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl/fl/mlcube/workspace/training_config.yaml @@ -12,6 +12,7 @@ aggregator: - GetExperimentStatus - AddCollaborator - RemoveCollaborator + - SetStragglerCuttoffTime template: openfl.component.Aggregator assigner: settings: diff --git a/examples/fl/fl_admin/mlcube/mlcube.yaml b/examples/fl/fl_admin/mlcube/mlcube.yaml index 838726a68..9ca9ded48 100644 --- a/examples/fl/fl_admin/mlcube/mlcube.yaml +++ b/examples/fl/fl_admin/mlcube/mlcube.yaml @@ -33,6 +33,14 @@ tasks: outputs: temp_dir: tmp/ remove_collaborator: + parameters: + inputs: + node_cert_folder: node_cert/ + ca_cert_folder: ca_cert/ + plan_path: plan.yaml + outputs: + temp_dir: tmp/ + set_straggler_cuttoff_time: parameters: inputs: node_cert_folder: node_cert/ diff --git a/examples/fl/fl_admin/project/admin.py b/examples/fl/fl_admin/project/admin.py index 04318810c..bf05f7e9c 100644 --- a/examples/fl/fl_admin/project/admin.py +++ b/examples/fl/fl_admin/project/admin.py @@ -4,6 +4,7 @@ get_col_cn_to_add, get_col_label_to_remove, get_col_cn_to_remove, + get_straggler_cutoff_time ) @@ -58,3 +59,19 @@ def remove_collaborator(workspace_folder, admin_cn): ], cwd=workspace_folder, ) + + +def set_straggler_cuttoff_time(workspace_folder, admin_cn): + timeout_in_seconds = get_straggler_cutoff_time() + check_call( + [ + "fx", + "admin", + "set_straggler_cuttoff_time", + "-n", + admin_cn, + "--timeout_in_seconds", + timeout_in_seconds, + ], + cwd=workspace_folder, + ) \ No newline at end of file diff --git a/examples/fl/fl_admin/project/mlcube.py 
b/examples/fl/fl_admin/project/mlcube.py index 9a5019a22..582ad4a4f 100644 --- a/examples/fl/fl_admin/project/mlcube.py +++ b/examples/fl/fl_admin/project/mlcube.py @@ -4,7 +4,7 @@ import shutil import typer from utils import setup_ws -from admin import get_experiment_status, add_collaborator, remove_collaborator +from admin import get_experiment_status, add_collaborator, remove_collaborator, set_straggler_cuttoff_time app = typer.Typer() @@ -68,5 +68,20 @@ def remove_collaborator_( _teardown(temp_dir) +@app.command("set_straggler_cuttoff_time") +def set_straggler_cuttoff_time_( + node_cert_folder: str = typer.Option(..., "--node_cert_folder"), + ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), + plan_path: str = typer.Option(..., "--plan_path"), + temp_dir: str = typer.Option(..., "--temp_dir"), +): + _setup(temp_dir) + workspace_folder, admin_cn = setup_ws( + node_cert_folder, ca_cert_folder, plan_path, temp_dir + ) + set_straggler_cuttoff_time(workspace_folder, admin_cn) + _teardown(temp_dir) + + if __name__ == "__main__": app() diff --git a/examples/fl/fl_admin/project/utils.py b/examples/fl/fl_admin/project/utils.py index 640ff72cc..7843b5f37 100644 --- a/examples/fl/fl_admin/project/utils.py +++ b/examples/fl/fl_admin/project/utils.py @@ -44,6 +44,10 @@ def get_col_cn_to_remove(): return os.environ["MEDPERF_COLLABORATOR_CN_TO_REMOVE"] +def get_straggler_cutoff_time(): + return os.environ["MEDPERF_STRAGGLER_CUTOFF_TIME"] + + def prepare_plan(plan_path, fl_workspace): target_plan_folder = os.path.join(fl_workspace, "plan") os.makedirs(target_plan_folder, exist_ok=True) From e7499256daac669395fd5a2ecc35ca3981f9bec1 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Jul 2024 16:52:40 +0200 Subject: [PATCH 098/242] update numpy version in requirements --- cli/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/requirements.txt b/cli/requirements.txt index 02d8ee05a..94384378c 100644 --- a/cli/requirements.txt +++ 
b/cli/requirements.txt @@ -22,6 +22,7 @@ setuptools<=66.1.1 email-validator==2.0.0 auth0-python==4.3.0 pandas==2.1.0 +numpy==1.26.4 watchdog==3.0.0 GitPython==3.1.41 psutil==5.9.8 From fe4bd92c15ac886399cd8057e69615b1380e80b8 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 30 Jul 2024 16:53:04 +0200 Subject: [PATCH 099/242] update testing scripts --- cli/cli_tests_training.sh | 135 ++++++++++------------ cli/medperf/commands/training/submit.py | 3 +- cli/medperf/commands/training/training.py | 2 +- 3 files changed, 67 insertions(+), 73 deletions(-) diff --git a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index fcfd7b4f4..04543d0ad 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -193,6 +193,58 @@ checkFailed "submit plan failed" echo "\n" +########################################################## +echo "=====================================" +echo "start event" +echo "=====================================" +echo "testdo@example.com: testdo@example.com" >>./testcols.yaml +echo "testdo2@example.com: testdo2@example.com" >>./testcols.yaml +print_eval medperf training start_event -n event1 -t $TRAINING_UID -p ./testcols.yaml -y +checkFailed "start event failed" +rm ./testcols.yaml + +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo "=====================================" +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get aggregator cert" +echo "=====================================" +print_eval medperf certificate get_server_certificate -t $TRAINING_UID +checkFailed "Get aggregator cert failed" 
+########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Starting aggregator" +echo "=====================================" +print_eval medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ agg.log 2>&1 & +AGG_PID=$! + +# sleep so that the mlcube is run before we change profiles +sleep 7 + +# Check if the command is still running. +if [ ! -d "/proc/$AGG_PID" ]; then + checkFailed "agg doesn't seem to be running" 1 +fi +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Activate dataowner profile" @@ -351,77 +403,6 @@ fi echo "\n" -########################################################## -echo "=====================================" -echo "Activate modelowner profile" -echo "=====================================" -print_eval medperf profile activate testmodel -checkFailed "testmodel profile activation failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Approve data1 association" -echo "=====================================" -print_eval medperf association approve -t $TRAINING_UID -d $DSET_1_UID -checkFailed "data1 association approval failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Approve data2 association" -echo "=====================================" -print_eval medperf association approve -t $TRAINING_UID -d $DSET_2_UID -checkFailed "data2 association approval failed" -########################################################## - -echo "\n" - -########################################################## -echo 
"=====================================" -echo "start event" -echo "=====================================" -print_eval medperf training start_event -n event1 -t $TRAINING_UID -y -checkFailed "start event failed" - -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Activate aggowner profile" -echo "=====================================" -print_eval medperf profile activate testagg -checkFailed "testagg profile activation failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Get aggregator cert" -echo "=====================================" -print_eval medperf certificate get_server_certificate -t $TRAINING_UID -checkFailed "Get aggregator cert failed" -########################################################## - -echo "\n" - -########################################################## -echo "=====================================" -echo "Starting aggregator" -echo "=====================================" -print_eval medperf aggregator start -t $TRAINING_UID -p $HOSTNAME_ -checkFailed "agg didn't exit successfully" -########################################################## - -echo "\n" - ########################################################## echo "=====================================" echo "Waiting for other prcocesses to exit successfully" @@ -436,6 +417,18 @@ wait $COL1_PID checkFailed "data1 training didn't exit successfully" wait $COL2_PID checkFailed "data2 training didn't exit successfully" +wait $AGG_PID +checkFailed "agg didn't exit successfully" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Activate aggowner profile" +echo 
"=====================================" +print_eval medperf profile activate testagg +checkFailed "testagg profile activation failed" ########################################################## echo "\n" diff --git a/cli/medperf/commands/training/submit.py b/cli/medperf/commands/training/submit.py index 9877497bc..02221da91 100644 --- a/cli/medperf/commands/training/submit.py +++ b/cli/medperf/commands/training/submit.py @@ -45,7 +45,8 @@ def get_fl_mlcube(self): def get_fl_admin_mlcube(self): mlcube_id = self.training_exp.fl_admin_mlcube - Cube.get(mlcube_id) + if mlcube_id: + Cube.get(mlcube_id) def submit(self): updated_body = self.training_exp.upload() diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index 717c6ed3c..7f44c7b50 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -28,7 +28,7 @@ def submit( ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" ), fl_admin_mlcube: int = typer.Option( - ..., "--fl-mlcube", "-a", help="FL admin interface MLCube" + None, "--fl-mlcube", "-a", help="FL admin interface MLCube" ), operational: bool = typer.Option( False, From 02689608913c114d07ef0f69904ed8b36b9eae11 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 1 Aug 2024 18:39:44 +0200 Subject: [PATCH 100/242] support gpu drivers v470 --- examples/fl/fl_admin/build.sh | 2 +- examples/fl_post/fl/build.sh | 15 ++++++------- examples/fl_post/fl/project/Dockerfile | 10 +++++++-- examples/fl_post/fl/project/collaborator.py | 5 ++++- examples/fl_post/fl/project/entrypoint.sh | 24 +++++++++++++++++++++ 5 files changed, 43 insertions(+), 13 deletions(-) create mode 100644 examples/fl_post/fl/project/entrypoint.sh diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh index 96b4c9216..8a5109e63 100644 --- a/examples/fl/fl_admin/build.sh +++ b/examples/fl/fl_admin/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git 
clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 + git checkout 54f27c61c274f64af3d028f962f62392419cb67e docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index a16113b11..8a5109e63 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -6,14 +6,11 @@ done BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then - # git clone https://github.com/hasan7n/openfl.git - # cd openfl - # git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 - # docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . - # # cd .. - # rm -rf openfl - - cd /home/msheller/git/openfl-hasan - cd - + git clone https://github.com/hasan7n/openfl.git + cd openfl + git checkout 54f27c61c274f64af3d028f962f62392419cb67e + docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl fi mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index 984c4cb3b..d12baa7bb 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -7,13 +7,19 @@ ENV CUDA_VISIBLE_DEVICES="0" # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ -# RUN pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 COPY ./requirements.txt /mlcube_project/requirements.txt RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt +# Create similar env with cuda118 +RUN apt-get update && apt-get install python3.10-venv -y +RUN python -m venv /cuda118 +RUN 
/cuda118/bin/pip install --no-cache-dir /openfl +RUN /cuda118/bin/pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 +RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt + # Copy mlcube project folder COPY . /mlcube_project -ENTRYPOINT ["python", "/mlcube_project/mlcube.py"] \ No newline at end of file +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] \ No newline at end of file diff --git a/examples/fl_post/fl/project/collaborator.py b/examples/fl_post/fl/project/collaborator.py index 38c5048b6..d187a1ab8 100644 --- a/examples/fl_post/fl/project/collaborator.py +++ b/examples/fl_post/fl/project/collaborator.py @@ -26,7 +26,10 @@ def start_collaborator( prepare_ca_cert(ca_cert_folder, workspace_folder) # set log files - check_call(["fx", "collaborator", "start", "-n", cn], cwd=workspace_folder) + check_call( + [os.environ.get("OPENFL_EXECUTABLE", "fx"), "collaborator", "start", "-n", cn], + cwd=workspace_folder, + ) # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) diff --git a/examples/fl_post/fl/project/entrypoint.sh b/examples/fl_post/fl/project/entrypoint.sh new file mode 100644 index 000000000..14d0056da --- /dev/null +++ b/examples/fl_post/fl/project/entrypoint.sh @@ -0,0 +1,24 @@ +PYTHONSCRIPT="import torch; torch.tensor([1.0, 2.0, 3.0, 4.0]).to('cuda')" + +if [ "$1" = "start_aggregator" ] || [ "$1" = "generate_plan" ]; then + # no need for gpu, don't test cuda + python /mlcube_project/mlcube.py $@ +else + echo "Testing which cuda version to use" + python -c "$PYTHONSCRIPT" + if [ "$?" -ne "0" ]; then + echo "cuda 12 failed. Trying with cuda 11.8" + /cuda118/bin/python -c "$PYTHONSCRIPT" + if [ "$?" -ne "0" ]; then + echo "No suppored cuda version satisfies the machine driver. Exiting." + exit 1 + else + echo "cuda 11.8 seems to be working. 
Will use cuda 11.8" + export OPENFL_EXECUTABLE="/cuda118/bin/fx" + /cuda118/bin/python /mlcube_project/mlcube.py $@ + fi + else + echo "cuda 12 seems to be working. Will use cuda 12" + python /mlcube_project/mlcube.py $@ + fi +fi From 16bf408627d10638ecc96e3b6c8b2d33cf2d516f Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 1 Aug 2024 18:39:57 +0200 Subject: [PATCH 101/242] aggregator copy files bugfix --- examples/fl_post/fl/project/aggregator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/aggregator.py b/examples/fl_post/fl/project/aggregator.py index c0bbeafa1..296adb9af 100644 --- a/examples/fl_post/fl/project/aggregator.py +++ b/examples/fl_post/fl/project/aggregator.py @@ -41,7 +41,8 @@ def start_aggregator( # perhaps investigate overriding plan entries? # NOTE: logs and weights are copied, even if target folders are not empty - copy_tree(os.path.join(workspace_folder, "logs"), output_logs) + if os.path.exists(os.path.join(workspace_folder, "logs")): + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) # NOTE: conversion fails since openfl needs sample data... 
# weights_paths = get_weights_path(fl_workspace) @@ -56,5 +57,5 @@ def start_aggregator( # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) - with open(report_path, 'w') as f: + with open(report_path, "w") as f: f.write("IsDone: 1") From 8dcbfba66d56c0b86e87cee2457e590aef3f7783 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 1 Aug 2024 20:04:05 +0200 Subject: [PATCH 102/242] singularity option for fl tests --- cli/medperf/commands/mlcube/run.py | 2 ++ examples/fl_post/fl/test.sh | 16 +++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cli/medperf/commands/mlcube/run.py b/cli/medperf/commands/mlcube/run.py index 75c9fb19e..86cb626b0 100644 --- a/cli/medperf/commands/mlcube/run.py +++ b/cli/medperf/commands/mlcube/run.py @@ -9,4 +9,6 @@ def run_mlcube(mlcube_path, task, out_logs, params, port, env): c.params_path = os.path.join( mlcube_path, config.workspace_path, config.params_filename ) + if config.platform == "singularity": + c._set_image_hash_from_registry() c.run(task, out_logs, port=port, env_dict=env, **params) diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index e0cf803a0..9463c56a1 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,5 +1,11 @@ # generate plan and copy it to each node -medperf mlcube run --mlcube ./mlcube_agg --task generate_plan +GENERATE_PLAN_PLATFORM="docker" +AGG_PLATFORM="docker" +COL1_PLATFORM="singularity" +COL2_PLATFORM="docker" +COL3_PLATFORM="docker" + +medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace rm -r ./mlcube_agg/workspace/plan cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace @@ -8,10 +14,10 @@ cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes -AGG="medperf mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf 
--gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & From 70e4e56ae5c69d3311557ca681aae025cba71618 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 1 Aug 2024 20:06:01 +0200 Subject: [PATCH 103/242] add test option for admin fl --- examples/fl/fl_admin/test.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/fl/fl_admin/test.sh b/examples/fl/fl_admin/test.sh index 3e781d937..13afbcd0f 100644 --- a/examples/fl/fl_admin/test.sh +++ b/examples/fl/fl_admin/test.sh @@ -24,3 +24,10 @@ # env_args="$env_arg1,$env_arg2,$env_arg3" # medperf mlcube run --mlcube ./mlcube_admin --task remove_collaborator \ # -e $env_args + +# # SET STRAGGLER CUTOFF +# env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" +# env_arg2="MEDPERF_STRAGGLER_CUTOFF_TIME=1200" +# env_args="$env_arg1,$env_arg2" +# medperf mlcube run --mlcube ./mlcube_admin --task set_straggler_cuttoff_time \ +# -e $env_args From b82618b08f8a7b7dae7d67990515ccbb24abe93b Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 15 
Aug 2024 19:12:58 -0700 Subject: [PATCH 104/242] enabled specification of batches per epoch for training and validation in both the core NNUnet train function and the runner. Still need to compute these using the desired partial_epoch rate and feed to runner instantiation, as well as have the dummy loader correspondingly modify the collaborator weight (via loader 'length') according to the partial epoch value --- examples/fl_post/fl/project/src/nnunet_v1.py | 10 +++++++++- .../fl_post/fl/project/src/runner_nnunetv1.py | 17 ++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 78869d4c1..4b18168f9 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -55,7 +55,9 @@ def seed_everything(seed=1234): def train_nnunet(epochs, - current_epoch, + current_epoch, + num_train_batches_per_epoch, + num_val_batches_per_epoch, network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', @@ -78,6 +80,10 @@ def train_nnunet(epochs, pretrained_weights=None): """ + epochs (int): Number of epochs to train for on top of current epoch + current_epoch (int): Which epoch will be used to grab the model + num_train_batches_per_epoch (int): Number of batches to train over each epoch (batches are sampled with replacement) + num_val_batches_per_epoch (int): Number of batches to validate on each epoch (batches are samples with replacement) task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" validation_only: use this if you want to only run the validation @@ -223,6 +229,8 @@ def __init__(self, **kwargs): ) trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch + trainer.num_batches_per_epoch = num_train_batches_per_epoch + trainer.num_val_batches_per_epoch = num_val_batches_per_epoch # TODO: call validation separately trainer.initialize(not 
validation_only) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 26f184522..550f478a2 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -34,14 +34,19 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, + num_train_batches_per_epoch, + num_val_batches_per_epoch, nnunet_task=None, config_path=None, **kwargs): """Initialize. Args: - config_path(str) : Path to the configuration file used by the training and validation script. - kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). + num_train_batches_per_epoch (int) : Number of batches to be samples (with replacemtnt) for training + num_val_batches_per_epoch (int) : Number of batches to be sampled (with replacement) for validation + nnunet_task (str) : Task string used to identify the data and model folders + config_path(str) : Path to the configuration file used by the training and validation script. + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). TODO: """ @@ -71,6 +76,8 @@ def __init__(self, **kwargs, ) + self.num_train_batches_per_epoch = num_train_batches_per_epoch + self.num_val_batches_per_epoch = num_val_batches_per_epoch self.config_path = config_path @@ -149,7 +156,11 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- train_nnunet(epochs=epochs, current_epoch=current_epoch, task=self.data_loader.get_task_name()) + train_nnunet(epochs=epochs, + current_epoch=current_epoch, + num_train_batches_per_epoch = self.num_train_batches_per_epoch, + num_val_batches_per_epoch = self.num_val_batches_per_epoch, + task=self.data_loader.get_task_name()) # 3. Load metrics from checkpoint (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] From 612b19176b35bc6a6d38caa93d6c271e407cf92a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 16 Aug 2024 11:27:12 -0700 Subject: [PATCH 105/242] dummy loader now also uses the per collaborator partial_epoch param --- .../fl/project/src/nnunet_dummy_dataloader.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py index 1fe83a4f5..a757b338c 100644 --- a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -12,16 +12,24 @@ import os class NNUNetDummyDataLoader(): - def __init__(self, data_path, p_train): + def __init__(self, data_path, p_train, partial_epoch=1.0): self.task_name = data_path data_base_path = os.path.join(os.environ['nnUNet_preprocessed'], self.task_name) with open(f'{data_base_path}/dataset.json', 'r') as f: data_config = json.load(f) data_size = data_config['numTraining'] + # NOTE: Intended use with PyTorchNNUNetCheckpointTaskRunner where partial_epoch scales down num_train_batches_per_epoch + # and num_val_batches_per_epoch. NNUnet loaders sample batches with replacement. Ignoring rounding (int()), + # the 'data sizes' below are divided by batch_size to obtain the number of batches used per epoch. 
+ # These 'data sizes' therefore establish correct relative weights for train and val result aggregation over collaboarators + # due to the fact that batch_size is equal across all collaborators. In addition, over many rounds each data point + # at a particular collaborator informs the results with equal measure. In particular, the average number of times (over + # repeated runs of the federation) that a particular sample is used for a training or val result + # over the corse of the whole federation is given by the 'data sizes' below. # TODO: determine how nnunet validation splits round - self.train_data_size = int(p_train * data_size) - self.valid_data_size = data_size - self.train_data_size + self.train_data_size = int(partial_epoch * p_train * data_size) + self.valid_data_size = int(partial_epoch * (1 - p_train) * data_size) def get_feature_shape(self): return [1,1,1] From 27f85e5ba37f18d1bb45c127ac3a107c12a54b71 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Sun, 18 Aug 2024 18:58:05 +0200 Subject: [PATCH 106/242] admin commands in medperf --- .../training/get_experiment_status.py | 76 +++++++++++++++++++ cli/medperf/commands/training/training.py | 32 ++++++++ cli/medperf/commands/training/update_plan.py | 74 ++++++++++++++++++ cli/medperf/config.py | 1 + cli/medperf/entities/cube.py | 2 +- cli/medperf/entities/training_exp.py | 1 + examples/fl/fl_admin/mlcube/mlcube.yaml | 2 +- .../fl/fl_admin/mlcube/workspace/plan.yaml | 8 +- examples/fl/fl_admin/project/admin.py | 25 +++--- examples/fl/fl_admin/project/mlcube.py | 13 +++- examples/fl/fl_admin/project/update_plan.py | 20 +++++ examples/fl/fl_admin/project/utils.py | 8 +- examples/fl/fl_admin/test.sh | 8 +- 13 files changed, 240 insertions(+), 30 deletions(-) create mode 100644 cli/medperf/commands/training/get_experiment_status.py create mode 100644 cli/medperf/commands/training/update_plan.py create mode 100644 examples/fl/fl_admin/project/update_plan.py diff --git 
a/cli/medperf/commands/training/get_experiment_status.py b/cli/medperf/commands/training/get_experiment_status.py new file mode 100644 index 000000000..49c333117 --- /dev/null +++ b/cli/medperf/commands/training/get_experiment_status.py @@ -0,0 +1,76 @@ +from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import get_pki_assets_path, generate_tmp_path, dict_pretty_print +from medperf.certificates import trust +import yaml + + +class GetExperimentStatus: + @classmethod + def run(cls, training_exp_id: int): + """Starts the aggregation server of a training experiment + + Args: + training_exp_id (int): Training experiment UID. + """ + execution = cls(training_exp_id) + execution.prepare() + execution.prepare_plan() + execution.prepare_pki_assets() + with config.ui.interactive(): + execution.prepare_admin_cube() + execution.get_experiment_status() + execution.print_experiment_status() + + def __init__(self, training_exp_id: int) -> None: + self.training_exp_id = training_exp_id + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Experiment: {self.training_exp.name}") + self.user_email: str = get_medperf_user_data()["email"] + self.status_output = self.training_exp.status_path + self.temp_dir = generate_tmp_path() + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def prepare_admin_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = ( + "Retrieving and setting up training MLCube. 
This may take some time." + ) + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def get_experiment_status(self): + env_dict = {"MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email} + params = { + "node_cert_folder": self.admin_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "output_status_file": self.status_output, + "temp_dir": self.temp_dir, + } + + self.ui.text = "Getting training experiment status" + self.cube.run(task="get_experiment_status", env_dict=env_dict, **params) + + def print_experiment_status(self): + with open(self.status_output) as f: + contents = yaml.safe_load(f) + dict_pretty_print(contents) diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index 7f44c7b50..b0cb9a153 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -11,6 +11,8 @@ from medperf.commands.training.close_event import CloseEvent from medperf.commands.list import EntityList from medperf.commands.view import EntityView +from medperf.commands.training.get_experiment_status import GetExperimentStatus +from medperf.commands.training.update_plan import UpdatePlan app = typer.Typer() @@ -86,6 +88,36 @@ def start_event( config.ui.print("✅ Done!") +@app.command("get_experiment_status") +@clean_except +def get_experiment_status( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ) +): + """Runs the benchmark execution step for a given benchmark, prepared dataset and model""" + GetExperimentStatus.run(training_exp_id) + config.ui.print("✅ Done!") + + +@app.command("update_plan") +@clean_except +def update_plan( + training_exp_id: int = typer.Option( + ..., "--training_exp_id", "-t", help="UID of the desired benchmark" + ), + field_name: str = typer.Option( + ..., "--field_name", "-f", help="UID of the desired 
benchmark" + ), + value: str = typer.Option( + ..., "--value", "-v", help="UID of the desired benchmark" + ), +): + """Runtime-update of a scalar field of the training plan""" + UpdatePlan.run(training_exp_id, field_name, value) + config.ui.print("✅ Done!") + + @app.command("close_event") @clean_except def close_event( diff --git a/cli/medperf/commands/training/update_plan.py b/cli/medperf/commands/training/update_plan.py new file mode 100644 index 000000000..f03676085 --- /dev/null +++ b/cli/medperf/commands/training/update_plan.py @@ -0,0 +1,74 @@ +from medperf import config +from medperf.account_management.account_management import get_medperf_user_data +from medperf.entities.ca import CA +from medperf.entities.training_exp import TrainingExp +from medperf.entities.cube import Cube +from medperf.utils import get_pki_assets_path, generate_tmp_path +from medperf.certificates import trust + + +class GetExperimentStatus: + @classmethod + def run(cls, training_exp_id: int, field_name: str, field_value: str): + """Starts the aggregation server of a training experiment + + Args: + training_exp_id (int): Training experiment UID. 
+ """ + execution = cls(training_exp_id, field_name, field_value) + execution.prepare() + execution.prepare_plan() + execution.prepare_pki_assets() + with config.ui.interactive(): + execution.prepare_admin_cube() + execution.update_plan() + + def __init__(self, training_exp_id: int, field_name: str, field_value: str) -> None: + self.training_exp_id = training_exp_id + self.field_name = field_name + self.field_value = field_value + self.ui = config.ui + + def prepare(self): + self.training_exp = TrainingExp.get(self.training_exp_id) + self.ui.print(f"Training Experiment: {self.training_exp.name}") + self.user_email: str = get_medperf_user_data()["email"] + self.temp_dir = generate_tmp_path() + + def prepare_plan(self): + self.training_exp.prepare_plan() + + def prepare_pki_assets(self): + ca = CA.from_experiment(self.training_exp_id) + trust(ca) + self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) + self.ca = ca + + def prepare_admin_cube(self): + self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") + + def __get_cube(self, uid: int, name: str) -> Cube: + self.ui.text = ( + "Retrieving and setting up training MLCube. This may take some time." 
+ ) + cube = Cube.get(uid) + cube.download_run_files() + self.ui.print(f"> {name} cube download complete") + return cube + + def update_plan(self): + env_dict = { + "MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email, + "MEDPERF_UPDATE_FIELD_NAME": self.field_name, + "MEDPERF_UPDATE_FIELD_VALUE": self.field_value, + } + + params = { + "node_cert_folder": self.admin_pki_assets, + "ca_cert_folder": self.ca.pki_assets, + "plan_path": self.training_exp.plan_path, + "temp_dir": self.temp_dir, + } + + self.ui.text = "Updating plan" + self.cube.run(task="update_plan", env_dict=env_dict, **params) diff --git a/cli/medperf/config.py b/cli/medperf/config.py index 8c504bd5c..b748db74e 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -170,6 +170,7 @@ training_exps_filename = "training-info.yaml" participants_list_filename = "cols.yaml" training_exp_plan_filename = "plan.yaml" +training_exp_status_filename = "status.yaml" training_report_file = "report.yaml" training_report_folder = "report" training_out_agg_logs = "agg_logs" diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 4a6050cd9..b3ffae880 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -261,7 +261,7 @@ def run( "get_experiment_status", "add_collaborator", "remove_collaborator", - "set_straggler_cuttoff_time", + "update_plan", ]: cmd += " --network=none" if config.gpus is not None: diff --git a/cli/medperf/entities/training_exp.py b/cli/medperf/entities/training_exp.py index b3cbff37e..d9d44c385 100644 --- a/cli/medperf/entities/training_exp.py +++ b/cli/medperf/entities/training_exp.py @@ -63,6 +63,7 @@ def __init__(self, *args, **kwargs): self.generated_uid = self.name self.plan_path = os.path.join(self.path, config.training_exp_plan_filename) + self.status_path = os.path.join(self.path, config.training_exp_status_filename) @classmethod def _Entity__remote_prefilter(cls, filters: dict) -> callable: diff --git 
a/examples/fl/fl_admin/mlcube/mlcube.yaml b/examples/fl/fl_admin/mlcube/mlcube.yaml index 9ca9ded48..7e0394c59 100644 --- a/examples/fl/fl_admin/mlcube/mlcube.yaml +++ b/examples/fl/fl_admin/mlcube/mlcube.yaml @@ -40,7 +40,7 @@ tasks: plan_path: plan.yaml outputs: temp_dir: tmp/ - set_straggler_cuttoff_time: + update_plan: parameters: inputs: node_cert_folder: node_cert/ diff --git a/examples/fl/fl_admin/mlcube/workspace/plan.yaml b/examples/fl/fl_admin/mlcube/workspace/plan.yaml index e08e37752..2905f2f18 100644 --- a/examples/fl/fl_admin/mlcube/workspace/plan.yaml +++ b/examples/fl/fl_admin/mlcube/workspace/plan.yaml @@ -6,12 +6,12 @@ aggregator: last_state_path: save/classification_last.pbuf rounds_to_train: 2 write_logs: true - admins: - - admin@example.com - allowed_admin_endpoints: + admins_endpoints_mapping: + admin@example.com: - GetExperimentStatus - AddCollaborator - - RemoveCollaborator + - SetStragglerCutoffTime + template: openfl.component.Aggregator assigner: settings: diff --git a/examples/fl/fl_admin/project/admin.py b/examples/fl/fl_admin/project/admin.py index bf05f7e9c..952db58c5 100644 --- a/examples/fl/fl_admin/project/admin.py +++ b/examples/fl/fl_admin/project/admin.py @@ -4,8 +4,10 @@ get_col_cn_to_add, get_col_label_to_remove, get_col_cn_to_remove, - get_straggler_cutoff_time + get_update_field_name, + get_update_value_name, ) +from update_plan import set_straggler_cutoff_time def get_experiment_status(workspace_folder, admin_cn, output_status_file): @@ -61,17 +63,10 @@ def remove_collaborator(workspace_folder, admin_cn): ) -def set_straggler_cuttoff_time(workspace_folder, admin_cn): - timeout_in_seconds = get_straggler_cutoff_time() - check_call( - [ - "fx", - "admin", - "set_straggler_cuttoff_time", - "-n", - admin_cn, - "--timeout_in_seconds", - timeout_in_seconds, - ], - cwd=workspace_folder, - ) \ No newline at end of file +def update_plan(workspace_folder, admin_cn): + field_name = get_update_field_name() + field_value = 
get_update_value_name() + if field_name == "straggler_handling_policy.settings.straggler_cutoff_time": + set_straggler_cutoff_time(workspace_folder, admin_cn, field_value) + else: + raise ValueError(f"Unsupported field name: {field_name}") diff --git a/examples/fl/fl_admin/project/mlcube.py b/examples/fl/fl_admin/project/mlcube.py index 582ad4a4f..7e412f743 100644 --- a/examples/fl/fl_admin/project/mlcube.py +++ b/examples/fl/fl_admin/project/mlcube.py @@ -4,7 +4,12 @@ import shutil import typer from utils import setup_ws -from admin import get_experiment_status, add_collaborator, remove_collaborator, set_straggler_cuttoff_time +from admin import ( + get_experiment_status, + add_collaborator, + remove_collaborator, + update_plan, +) app = typer.Typer() @@ -68,8 +73,8 @@ def remove_collaborator_( _teardown(temp_dir) -@app.command("set_straggler_cuttoff_time") -def set_straggler_cuttoff_time_( +@app.command("update_plan") +def update_plan_( node_cert_folder: str = typer.Option(..., "--node_cert_folder"), ca_cert_folder: str = typer.Option(..., "--ca_cert_folder"), plan_path: str = typer.Option(..., "--plan_path"), @@ -79,7 +84,7 @@ def set_straggler_cuttoff_time_( workspace_folder, admin_cn = setup_ws( node_cert_folder, ca_cert_folder, plan_path, temp_dir ) - set_straggler_cuttoff_time(workspace_folder, admin_cn) + update_plan(workspace_folder, admin_cn) _teardown(temp_dir) diff --git a/examples/fl/fl_admin/project/update_plan.py b/examples/fl/fl_admin/project/update_plan.py new file mode 100644 index 000000000..2b7013849 --- /dev/null +++ b/examples/fl/fl_admin/project/update_plan.py @@ -0,0 +1,20 @@ +from subprocess import check_call + + +def set_straggler_cutoff_time(workspace_folder, admin_cn, field_value): + if not field_value.isnumeric(): + raise TypeError( + f"Expected an integer for straggler cutoff time, got {field_value}" + ) + check_call( + [ + "fx", + "admin", + "set_straggler_cutoff_time", + "-n", + admin_cn, + "--timeout_in_seconds", + field_value, + ], 
+ cwd=workspace_folder, + ) diff --git a/examples/fl/fl_admin/project/utils.py b/examples/fl/fl_admin/project/utils.py index 7843b5f37..eaa39dc4c 100644 --- a/examples/fl/fl_admin/project/utils.py +++ b/examples/fl/fl_admin/project/utils.py @@ -44,8 +44,12 @@ def get_col_cn_to_remove(): return os.environ["MEDPERF_COLLABORATOR_CN_TO_REMOVE"] -def get_straggler_cutoff_time(): - return os.environ["MEDPERF_STRAGGLER_CUTOFF_TIME"] +def get_update_field_name(): + return os.environ["MEDPERF_UPDATE_FIELD_NAME"] + + +def get_update_value_name(): + return os.environ["MEDPERF_UPDATE_FIELD_VALUE"] def prepare_plan(plan_path, fl_workspace): diff --git a/examples/fl/fl_admin/test.sh b/examples/fl/fl_admin/test.sh index 13afbcd0f..2a45baa16 100644 --- a/examples/fl/fl_admin/test.sh +++ b/examples/fl/fl_admin/test.sh @@ -27,7 +27,9 @@ # # SET STRAGGLER CUTOFF # env_arg1="MEDPERF_ADMIN_PARTICIPANT_CN=admin@example.com" -# env_arg2="MEDPERF_STRAGGLER_CUTOFF_TIME=1200" -# env_args="$env_arg1,$env_arg2" -# medperf mlcube run --mlcube ./mlcube_admin --task set_straggler_cuttoff_time \ +# env_arg2="MEDPERF_UPDATE_FIELD_NAME=straggler_handling_policy.settings.straggler_cutoff_time" +# env_arg3="MEDPERF_UPDATE_FIELD_VALUE=1200" + +# env_args="$env_arg1,$env_arg2,$env_arg3" +# medperf mlcube run --mlcube ./mlcube_admin --task update_plan \ # -e $env_args From 31f39caf4c638aa38f9e310248692d1b489c0804 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 22 Aug 2024 03:20:50 +0200 Subject: [PATCH 107/242] typo --- cli/medperf/commands/training/update_plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/commands/training/update_plan.py b/cli/medperf/commands/training/update_plan.py index f03676085..baa064300 100644 --- a/cli/medperf/commands/training/update_plan.py +++ b/cli/medperf/commands/training/update_plan.py @@ -7,7 +7,7 @@ from medperf.certificates import trust -class GetExperimentStatus: +class UpdatePlan: @classmethod def run(cls, training_exp_id: int, 
field_name: str, field_value: str): """Starts the aggregation server of a training experiment From 44ed99b8818b636e9ab3ec6772045da8c1147884 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 22 Aug 2024 13:33:01 +0200 Subject: [PATCH 108/242] refactor steps in fl mlcube connectivity check, setup certs, now happen before nnunet preprocessing --- examples/fl_post/fl/project/aggregator.py | 31 +-------- examples/fl_post/fl/project/collaborator.py | 38 +++++------ examples/fl_post/fl/project/hooks.py | 71 +++++++-------------- examples/fl_post/fl/project/init_model.py | 14 +--- examples/fl_post/fl/project/mlcube.py | 55 ++++++++-------- examples/fl_post/fl/project/utils.py | 53 +++++++++++++++ 6 files changed, 123 insertions(+), 139 deletions(-) diff --git a/examples/fl_post/fl/project/aggregator.py b/examples/fl_post/fl/project/aggregator.py index 296adb9af..8a1e7f283 100644 --- a/examples/fl_post/fl/project/aggregator.py +++ b/examples/fl_post/fl/project/aggregator.py @@ -1,39 +1,10 @@ -from utils import ( - get_aggregator_fqdn, - prepare_node_cert, - prepare_ca_cert, - prepare_plan, - prepare_cols_list, - prepare_init_weights, - create_workspace, - get_weights_path, -) - import os import shutil from subprocess import check_call from distutils.dir_util import copy_tree -def start_aggregator( - input_weights, - node_cert_folder, - ca_cert_folder, - output_logs, - output_weights, - plan_path, - collaborators, - report_path, -): - - workspace_folder = os.path.join(output_logs, "workspace") - create_workspace(workspace_folder) - prepare_plan(plan_path, workspace_folder) - prepare_cols_list(collaborators, workspace_folder) - prepare_init_weights(input_weights, workspace_folder) - fqdn = get_aggregator_fqdn(workspace_folder) - prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) - prepare_ca_cert(ca_cert_folder, workspace_folder) +def start_aggregator(workspace_folder, output_logs, output_weights, report_path): check_call(["fx", "aggregator", 
"start"], cwd=workspace_folder) diff --git a/examples/fl_post/fl/project/collaborator.py b/examples/fl_post/fl/project/collaborator.py index d187a1ab8..fb4cdd1c2 100644 --- a/examples/fl_post/fl/project/collaborator.py +++ b/examples/fl_post/fl/project/collaborator.py @@ -1,31 +1,11 @@ -from utils import ( - get_collaborator_cn, - prepare_node_cert, - prepare_ca_cert, - prepare_plan, - create_workspace, -) import os +from utils import get_collaborator_cn import shutil from subprocess import check_call -def start_collaborator( - data_path, - labels_path, - node_cert_folder, - ca_cert_folder, - plan_path, - output_logs, -): - workspace_folder = os.path.join(output_logs, "workspace") - create_workspace(workspace_folder) - prepare_plan(plan_path, workspace_folder) +def start_collaborator(workspace_folder): cn = get_collaborator_cn() - prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) - prepare_ca_cert(ca_cert_folder, workspace_folder) - - # set log files check_call( [os.environ.get("OPENFL_EXECUTABLE", "fx"), "collaborator", "start", "-n", cn], cwd=workspace_folder, @@ -33,3 +13,17 @@ def start_collaborator( # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) + + +def check_connectivity(workspace_folder): + cn = get_collaborator_cn() + check_call( + [ + os.environ.get("OPENFL_EXECUTABLE", "fx"), + "collaborator", + "connectivity_check", + "-n", + cn, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py index a079bd586..516853743 100644 --- a/examples/fl_post/fl/project/hooks.py +++ b/examples/fl_post/fl/project/hooks.py @@ -1,29 +1,8 @@ import os import shutil -import pandas as pd from utils import get_collaborator_cn -def __modify_df(df): - # gandlf convention: labels columns could be "target", "label", "mask" - # subject id column is subjectid. data columns are Channel_0. - # Others could be scalars. 
# TODO - labels_columns = ["target", "label", "mask"] - data_columns = ["channel_0"] - subject_id_column = "subjectid" - for column in df.columns: - if column.lower() == subject_id_column: - continue - if column.lower() in labels_columns: - prepend_str = "labels/" - elif column.lower() in data_columns: - prepend_str = "data/" - else: - continue - - df[column] = prepend_str + df[column].astype(str) - - def collaborator_pre_training_hook( data_path, labels_path, @@ -32,41 +11,35 @@ def collaborator_pre_training_hook( plan_path, output_logs, init_nnunet_directory, + workspace_folder, ): - # runtime env vars should be set as early as possible - tmpfolder = os.path.join(output_logs, ".tmp") - os.environ["TMPDIR"] = tmpfolder - os.makedirs(tmpfolder, exist_ok=True) - os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") - os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") - os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") import nnunet_setup - + cn = get_collaborator_cn() - workspace_folder = os.path.join(output_logs, "workspace") - os.makedirs(workspace_folder, exist_ok=True) os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) # this function returns metadata (model weights and config file) to be distributed out of band - # evan should use this without stuff to overwrite/sync so that it produces the correct metdata + # evan should use this without stuff to overwrite/sync so that it produces the correct metdata # when evan runs, init_model_path, init_model_info_path should be None - # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) + # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) - nnunet_setup.main(postopp_pardir=workspace_folder, - 
three_digit_task_num=537, # FIXME: does this need to be set in any particular way? - init_model_path=f'{init_nnunet_directory}/model_initial_checkpoint.model', - init_model_info_path=f'{init_nnunet_directory}/model_initial_checkpoint.model.pkl', - task_name='FLPost', - percent_train=.8, - split_logic='by_subject_time_pair', - network='3d_fullres', - network_trainer='nnUNetTrainerV2', - fold='0', - plans_path=f'{init_nnunet_directory}/nnUNetPlans_pretrained_POSTOPP_plans_3D.pkl', # NOTE: IT IS NOT AN OPENFL PLAN - cuda_device='0', - verbose=False) + nnunet_setup.main( + postopp_pardir=workspace_folder, + three_digit_task_num=537, # FIXME: does this need to be set in any particular way? + init_model_path=f"{init_nnunet_directory}/model_initial_checkpoint.model", + init_model_info_path=f"{init_nnunet_directory}/model_initial_checkpoint.model.pkl", + task_name="FLPost", + percent_train=0.8, + split_logic="by_subject_time_pair", + network="3d_fullres", + network_trainer="nnUNetTrainerV2", + fold="0", + plans_path=f"{init_nnunet_directory}/nnUNetPlans_pretrained_POSTOPP_plans_3D.pkl", # NOTE: IT IS NOT AN OPENFL PLAN + cuda_device="0", + verbose=False, + ) data_config = f"{cn},Task537_FLPost" plan_folder = os.path.join(workspace_folder, "plan") @@ -74,7 +47,8 @@ def collaborator_pre_training_hook( data_config_path = os.path.join(plan_folder, "data.yaml") with open(data_config_path, "w") as f: f.write(data_config) - shutil.copytree('/mlcube_project/src', os.path.join(workspace_folder, 'src')) + shutil.copytree("/mlcube_project/src", os.path.join(workspace_folder, "src")) + def collaborator_post_training_hook( data_path, @@ -83,6 +57,7 @@ def collaborator_post_training_hook( ca_cert_folder, plan_path, output_logs, + workspace_folder, ): pass @@ -96,6 +71,7 @@ def aggregator_pre_training_hook( plan_path, collaborators, report_path, + workspace_folder, ): pass @@ -109,5 +85,6 @@ def aggregator_post_training_hook( plan_path, collaborators, report_path, + workspace_folder, ): 
pass diff --git a/examples/fl_post/fl/project/init_model.py b/examples/fl_post/fl/project/init_model.py index d436e26d3..c2106b505 100644 --- a/examples/fl_post/fl/project/init_model.py +++ b/examples/fl_post/fl/project/init_model.py @@ -3,22 +3,10 @@ def train_initial_model( - data_path, - labels_path, - output_logs, - init_nnunet_directory, + data_path, labels_path, init_nnunet_directory, workspace_folder ): - # runtime env vars should be set as early as possible - tmpfolder = os.path.join(output_logs, ".tmp") - os.makedirs(tmpfolder, exist_ok=True) - os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") - os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") - os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, "nnUNet_preprocessed") import nnunet_setup - workspace_folder = os.path.join(output_logs, "workspace") - os.makedirs(workspace_folder, exist_ok=True) - os.symlink(data_path, f"{workspace_folder}/data", target_is_directory=True) os.symlink(labels_path, f"{workspace_folder}/labels", target_is_directory=True) diff --git a/examples/fl_post/fl/project/mlcube.py b/examples/fl_post/fl/project/mlcube.py index ba0140395..14694df94 100644 --- a/examples/fl_post/fl/project/mlcube.py +++ b/examples/fl_post/fl/project/mlcube.py @@ -1,9 +1,7 @@ """MLCube handler file""" -import os -import shutil import typer -from collaborator import start_collaborator +from collaborator import start_collaborator, check_connectivity from aggregator import start_aggregator from plan import generate_plan from hooks import ( @@ -12,23 +10,12 @@ collaborator_pre_training_hook, collaborator_post_training_hook, ) +from utils import generic_setup, generic_teardown, setup_collaborator, setup_aggregator from init_model import train_initial_model app = typer.Typer() -def _setup(output_logs): - tmp_folder = os.path.join(output_logs, ".tmp") - os.makedirs(tmp_folder, exist_ok=True) - # TODO: this should be set before any code imports 
tempfile - os.environ["TMPDIR"] = tmp_folder - - -def _teardown(output_logs): - tmp_folder = os.path.join(output_logs, ".tmp") - shutil.rmtree(tmp_folder, ignore_errors=True) - - @app.command("train") def train( data_path: str = typer.Option(..., "--data_path"), @@ -39,24 +26,28 @@ def train( output_logs: str = typer.Option(..., "--output_logs"), init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), ): - _setup(output_logs) - collaborator_pre_training_hook( + workspace_folder = generic_setup(output_logs) + setup_collaborator( data_path=data_path, labels_path=labels_path, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, - init_nnunet_directory=init_nnunet_directory, + workspace_folder=workspace_folder, ) - start_collaborator( + check_connectivity(workspace_folder) + collaborator_pre_training_hook( data_path=data_path, labels_path=labels_path, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, + init_nnunet_directory=init_nnunet_directory, + workspace_folder=workspace_folder, ) + start_collaborator(workspace_folder=workspace_folder) collaborator_post_training_hook( data_path=data_path, labels_path=labels_path, @@ -64,8 +55,9 @@ def train( ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, + workspace_folder=workspace_folder, ) - _teardown(output_logs) + generic_teardown(output_logs) @app.command("start_aggregator") @@ -79,8 +71,8 @@ def start_aggregator_( collaborators: str = typer.Option(..., "--collaborators"), report_path: str = typer.Option(..., "--report_path"), ): - _setup(output_logs) - aggregator_pre_training_hook( + workspace_folder = generic_setup(output_logs) + setup_aggregator( input_weights=input_weights, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, @@ -89,8 +81,9 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + 
workspace_folder=workspace_folder, ) - start_aggregator( + aggregator_pre_training_hook( input_weights=input_weights, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, @@ -99,6 +92,13 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + workspace_folder=workspace_folder, + ) + start_aggregator( + workspace_folder=workspace_folder, + output_logs=output_logs, + output_weights=output_weights, + report_path=report_path, ) aggregator_post_training_hook( input_weights=input_weights, @@ -109,8 +109,9 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + workspace_folder=workspace_folder, ) - _teardown(output_logs) + generic_teardown(output_logs) @app.command("generate_plan") @@ -132,14 +133,14 @@ def train_initial_model_( output_logs: str = typer.Option(..., "--output_logs"), init_nnunet_directory: str = typer.Option(..., "--init_nnunet_directory"), ): - _setup(output_logs) + workspace_folder = generic_setup(output_logs) train_initial_model( data_path=data_path, labels_path=labels_path, - output_logs=output_logs, init_nnunet_directory=init_nnunet_directory, + workspace_folder=workspace_folder, ) - _teardown(output_logs) + generic_teardown(output_logs) if __name__ == "__main__": diff --git a/examples/fl_post/fl/project/utils.py b/examples/fl_post/fl/project/utils.py index d92add606..c656f4d3c 100644 --- a/examples/fl_post/fl/project/utils.py +++ b/examples/fl_post/fl/project/utils.py @@ -3,6 +3,59 @@ import shutil +def generic_setup(output_logs): + tmpfolder = os.path.join(output_logs, ".tmp") + os.makedirs(tmpfolder, exist_ok=True) + # NOTE: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmpfolder + os.environ["RESULTS_FOLDER"] = os.path.join(tmpfolder, "nnUNet_trained_models") + os.environ["nnUNet_raw_data_base"] = os.path.join(tmpfolder, "nnUNet_raw_data_base") + os.environ["nnUNet_preprocessed"] = os.path.join(tmpfolder, 
"nnUNet_preprocessed") + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + create_workspace(workspace_folder) + return workspace_folder + + +def setup_collaborator( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def setup_aggregator( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def generic_teardown(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + def create_workspace(fl_workspace): plan_folder = os.path.join(fl_workspace, "plan") workspace_config = os.path.join(fl_workspace, ".workspace") From f82717ff41b1a3aebd7fa93db99d8a34a1b29ea6 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 22 Aug 2024 13:48:02 +0200 Subject: [PATCH 109/242] modify tests scripts --- examples/fl_post/fl/build.sh | 2 +- examples/fl_post/fl/setup_test_no_docker.sh | 37 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 8a5109e63..0100f8319 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 
54f27c61c274f64af3d028f962f62392419cb67e + git checkout 84819a5d28abff9c196df61cb931342464c0868d docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index 58276a159..745dffc47 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -114,3 +114,40 @@ rm csr.csr mkdir ../ca_cert cp -r ../../ca/root.crt ../ca_cert/root.crt cd ../.. + +# data setup +cd mlcube_col1/workspace +wget https://storage.googleapis.com/medperf-storage/fltest29July/small_test_data1.tar.gz +tar -xf small_test_data1.tar.gz +mv small_test_data1/* . +rm -rf small_test_data1.tar.gz +rm -rf small_test_data1 +cd ../../ + +cd mlcube_col2/workspace +wget https://storage.googleapis.com/medperf-storage/fltest29July/small_test_data2.tar.gz +tar -xf small_test_data2.tar.gz +mv small_test_data2/* . +rm -rf small_test_data2.tar.gz +rm -rf small_test_data2 +cd ../../ + +cd mlcube_col3/workspace +wget https://storage.googleapis.com/medperf-storage/fltest29July/small_test_data3.tar.gz +tar -xf small_test_data3.tar.gz +mv small_test_data3/* . 
+rm -rf small_test_data3.tar.gz +rm -rf small_test_data3 +cd ../../ + +# weights setup +cd mlcube_agg/workspace +mkdir additional_files +cd additional_files +wget https://storage.googleapis.com/medperf-storage/fltest29July/flpost_add29july.tar.gz +tar -xf flpost_add29july.tar.gz +rm flpost_add29july.tar.gz +cd ../../../ +cp -r mlcube_agg/workspace/additional_files mlcube_col1/workspace/additional_files +cp -r mlcube_agg/workspace/additional_files mlcube_col2/workspace/additional_files +cp -r mlcube_agg/workspace/additional_files mlcube_col3/workspace/additional_files From 3d7b6b1a822ef8ddccb9cf7dd911df6d94294115 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 27 Aug 2024 05:10:33 +0200 Subject: [PATCH 110/242] add error checking in step-ca client --- examples/fl/cert/project/get_cert.sh | 8 ++++++++ examples/fl/cert/project/trust.sh | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/examples/fl/cert/project/get_cert.sh b/examples/fl/cert/project/get_cert.sh index 39aa899a8..c0252b4be 100644 --- a/examples/fl/cert/project/get_cert.sh +++ b/examples/fl/cert/project/get_cert.sh @@ -83,5 +83,13 @@ step ca certificate --ca-url $CA_ADDRESS:$CA_PORT \ $PROVISIONER_ARGS \ $MEDPERF_INPUT_CN $cert_path $key_path +EXITSTATUS="$?" +if [ $EXITSTATUS -ne "0" ]; then + echo "Failed to get the certificate" + # cleanup + rm -rf $STEPPATH + exit 1 +fi + # cleanup rm -rf $STEPPATH diff --git a/examples/fl/cert/project/trust.sh b/examples/fl/cert/project/trust.sh index ceb2a303a..c33a96ea9 100644 --- a/examples/fl/cert/project/trust.sh +++ b/examples/fl/cert/project/trust.sh @@ -47,6 +47,13 @@ if [ -n "$CA_FINGERPRINT" ]; then else wget -O $pki_assets/root_ca.crt $CA_ADDRESS:$CA_PORT/roots.pem fi +EXITSTATUS="$?" 
+if [ $EXITSTATUS -ne "0" ]; then + echo "Failed to retrieve the root certificate" + # cleanup + rm -rf $STEPPATH + exit 1 +fi # cleanup rm -rf $STEPPATH From 4e59c4abe13a39b6a2ea1ae73d7bdebb362eaa43 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 27 Aug 2024 18:35:43 -0700 Subject: [PATCH 111/242] some clean up, as well as backing off to only pass partial_epoch to runner and train method, inferring number of batches from data size and batch size via the trainer object --- .../fl_post/fl/project/nnunet_model_setup.py | 44 ------------------- examples/fl_post/fl/project/nnunet_setup.py | 3 +- examples/fl_post/fl/project/src/nnunet_v1.py | 14 ++++-- .../fl_post/fl/project/src/runner_nnunetv1.py | 12 ++--- 4 files changed, 16 insertions(+), 57 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 4ebd1f9e7..cd9412790 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -60,51 +60,7 @@ def delete_2d_data(network, task, plans_identifier): print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") shutil.rmtree(data_dir_2d) -""" -def normalize_architecture(reference_plan_path, target_plan_path): - - # Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path - # NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more - likely leed to an exception. 
- - - - assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', - 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', - 'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier'] - copy_over_keys = ['plans_per_stage'] - nullify_keys = ['original_spacings', 'original_sizes'] - leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties'] - - - # check I got all keys here - assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage']) - - def get_pickle_obj(path): - with open(path, 'rb') as _file: - plan= pkl.load(_file) - return plan - - def write_pickled_obj(obj, path): - with open(path, 'wb') as _file: - pkl.dump(obj, _file) - - reference_plan = get_pickle_obj(path=reference_plan_path) - target_plan = get_pickle_obj(path=target_plan_path) - - for key in assert_same_keys: - if reference_plan[key] != target_plan[key]: - raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}") - for key in copy_over_keys: - target_plan[key] = reference_plan[key] - for key in nullify_keys: - target_plan[key] = None - # leave alone keys are left alone :) - - # write back to target plan - write_pickled_obj(obj=target_plan, path=target_plan_path) -""" def trim_data_and_setup_model(task, network, network_trainer, 
plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'): """ diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 0106b8f95..5ce499996 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -105,7 +105,7 @@ def main(postopp_pardir, fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. - overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one les than number of insitutions. + overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories for given task number and name. task_name(str) : Any string task name. timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. 
for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far @@ -126,6 +126,7 @@ def main(postopp_pardir, # task_folder_info is a zipped lists indexed over tasks (collaborators) # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) + col_paths = setup_fl_data(postopp_pardir=postopp_pardir, three_digit_task_num=three_digit_task_num, task_name=task_name, diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 4b18168f9..c1d1a7458 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -56,8 +56,7 @@ def seed_everything(seed=1234): def train_nnunet(epochs, current_epoch, - num_train_batches_per_epoch, - num_val_batches_per_epoch, + partial_epoch=1.0, network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', @@ -82,8 +81,7 @@ def train_nnunet(epochs, """ epochs (int): Number of epochs to train for on top of current epoch current_epoch (int): Which epoch will be used to grab the model - num_train_batches_per_epoch (int): Number of batches to train over each epoch (batches are sampled with replacement) - num_val_batches_per_epoch (int): Number of batches to validate on each epoch (batches are samples with replacement) + partial_epoch (float): task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" validation_only: use this if you want to only run the validation @@ -229,6 +227,14 @@ def __init__(self, **kwargs): ) trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch + + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data + # point is expected to be seen epochs number of times + + num_train_batches_per_epoch = int(partial_epoch * len(trainer.dataset_tr)/trainer.batch_size) + num_val_batches_per_epoch = int(partial_epoch * 
len(trainer.dataset_val)/trainer.batch_size) + + # the nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch trainer.num_val_batches_per_epoch = num_val_batches_per_epoch diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 550f478a2..0116e4105 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -34,16 +34,14 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - num_train_batches_per_epoch, - num_val_batches_per_epoch, + partial_epoch=1.0, nnunet_task=None, config_path=None, **kwargs): """Initialize. Args: - num_train_batches_per_epoch (int) : Number of batches to be samples (with replacemtnt) for training - num_val_batches_per_epoch (int) : Number of batches to be sampled (with replacement) for validation + partial_epoch (float) : What portion of the data to use to compute number of batches per epoch (for both train and val). nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). @@ -76,8 +74,7 @@ def __init__(self, **kwargs, ) - self.num_train_batches_per_epoch = num_train_batches_per_epoch - self.num_val_batches_per_epoch = num_val_batches_per_epoch + self.partial_epoch = partial_epoch self.config_path = config_path @@ -158,8 +155,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Should we put this in a separate process? 
train_nnunet(epochs=epochs, current_epoch=current_epoch, - num_train_batches_per_epoch = self.num_train_batches_per_epoch, - num_val_batches_per_epoch = self.num_val_batches_per_epoch, + partial_epoch=self.partial_epoch, task=self.data_loader.get_task_name()) # 3. Load metrics from checkpoint From a27ef7bb99c7663ae2ef136276458ad54cb6b033 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 2 Sep 2024 22:00:17 +0200 Subject: [PATCH 112/242] update commit hash for fl admin mlcube --- examples/fl/fl_admin/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh index 8a5109e63..0100f8319 100644 --- a/examples/fl/fl_admin/build.sh +++ b/examples/fl/fl_admin/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 54f27c61c274f64af3d028f962f62392419cb67e + git checkout 84819a5d28abff9c196df61cb931342464c0868d docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. 
rm -rf openfl From bc431ffe6c3b761b28674816e6f26511e8b27042 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Mon, 2 Sep 2024 20:52:22 +0000 Subject: [PATCH 113/242] update step-ca client dockerfile --- examples/fl/cert/project/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/fl/cert/project/Dockerfile b/examples/fl/cert/project/Dockerfile index 227625bee..55ba00a59 100644 --- a/examples/fl/cert/project/Dockerfile +++ b/examples/fl/cert/project/Dockerfile @@ -1,7 +1,6 @@ FROM python:3.11.9-alpine -# update openssl to fix https://avd.aquasec.com/nvd/cve-2024-2511 -RUN apk update && apk add openssl=3.1.4-r6 jq +RUN apk update && apk add jq ARG VERSION=0.26.1 RUN wget https://dl.smallstep.com/gh-release/cli/gh-release-header/v${VERSION}/step_linux_${VERSION}_amd64.tar.gz \ From 0f7d43c380d5ed7551add17edc370235d784cbb3 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 09:33:17 +0000 Subject: [PATCH 114/242] bugfix argument name in training.submit --- cli/medperf/commands/training/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/medperf/commands/training/training.py b/cli/medperf/commands/training/training.py index b0cb9a153..9a64bced6 100644 --- a/cli/medperf/commands/training/training.py +++ b/cli/medperf/commands/training/training.py @@ -30,7 +30,7 @@ def submit( ..., "--fl-mlcube", "-m", help="Reference Model MLCube UID" ), fl_admin_mlcube: int = typer.Option( - None, "--fl-mlcube", "-a", help="FL admin interface MLCube" + None, "--fl-admin-mlcube", "-a", help="FL admin interface MLCube" ), operational: bool = typer.Option( False, From fb4aa4f14e44d50fcccca814458e4ceeb15afd5c Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 09:33:41 +0000 Subject: [PATCH 115/242] store status in a temporary file to avoid read-only problem --- .../commands/training/get_experiment_status.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git 
a/cli/medperf/commands/training/get_experiment_status.py b/cli/medperf/commands/training/get_experiment_status.py index 49c333117..ec0eacefc 100644 --- a/cli/medperf/commands/training/get_experiment_status.py +++ b/cli/medperf/commands/training/get_experiment_status.py @@ -3,9 +3,15 @@ from medperf.entities.ca import CA from medperf.entities.training_exp import TrainingExp from medperf.entities.cube import Cube -from medperf.utils import get_pki_assets_path, generate_tmp_path, dict_pretty_print +from medperf.utils import ( + get_pki_assets_path, + generate_tmp_path, + dict_pretty_print, + remove_path, +) from medperf.certificates import trust import yaml +import os class GetExperimentStatus: @@ -24,6 +30,7 @@ def run(cls, training_exp_id: int): execution.prepare_admin_cube() execution.get_experiment_status() execution.print_experiment_status() + execution.store_status() def __init__(self, training_exp_id: int) -> None: self.training_exp_id = training_exp_id @@ -33,7 +40,7 @@ def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Experiment: {self.training_exp.name}") self.user_email: str = get_medperf_user_data()["email"] - self.status_output = self.training_exp.status_path + self.status_output = generate_tmp_path() self.temp_dir = generate_tmp_path() def prepare_plan(self): @@ -74,3 +81,8 @@ def print_experiment_status(self): with open(self.status_output) as f: contents = yaml.safe_load(f) dict_pretty_print(contents) + + def store_status(self): + new_status_path = self.training_exp.status_path + remove_path(new_status_path) + os.rename(self.status_output, new_status_path) From d6cd39dd752490388acfe3aae505d631ce6179d6 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 09:34:00 +0000 Subject: [PATCH 116/242] update cli integration tests --- cli/cli_tests_training.sh | 61 ++++++++++++++++++++++++++++++++++++++- cli/tests_setup.sh | 2 ++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git 
a/cli/cli_tests_training.sh b/cli/cli_tests_training.sh index 04543d0ad..9e84ec05d 100644 --- a/cli/cli_tests_training.sh +++ b/cli/cli_tests_training.sh @@ -20,6 +20,8 @@ print_eval medperf profile create -n testdata1 checkFailed "testdata1 profile creation failed" print_eval medperf profile create -n testdata2 checkFailed "testdata2 profile creation failed" +print_eval medperf profile create -n fladmin +checkFailed "fladmin profile creation failed" ########################################################## echo "\n" @@ -71,6 +73,13 @@ checkFailed "testdata2 profile activation failed" print_eval medperf auth login -e $DATAOWNER2 checkFailed "testdata2 login failed" + +print_eval medperf profile activate fladmin +checkFailed "fladmin profile activation failed" + +print_eval medperf auth login -e $FLADMIN +checkFailed "fladmin login failed" + ########################################################## echo "\n" @@ -97,6 +106,11 @@ PREP_UID=$(medperf mlcube ls | grep trainprep | head -n 1 | tr -s ' ' | cut -d ' print_eval medperf mlcube submit --name traincube -m $TRAIN_MLCUBE -a $TRAIN_WEIGHTS --operational checkFailed "traincube submission failed" TRAINCUBE_UID=$(medperf mlcube ls | grep traincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + +print_eval medperf mlcube submit --name fladmincube -m $FLADMIN_MLCUBE --operational +checkFailed "fladmincube submission failed" +FLADMINCUBE_UID=$(medperf mlcube ls | grep fladmincube | head -n 1 | tr -s ' ' | cut -d ' ' -f 2) + ########################################################## echo "\n" @@ -105,7 +119,7 @@ echo "\n" echo "=====================================" echo "Submit Training Experiment" echo "=====================================" -print_eval medperf training submit -n trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID +print_eval medperf training submit -n trainexp -d trainexp -p $PREP_UID -m $TRAINCUBE_UID -a $FLADMINCUBE_UID checkFailed "Training exp submission failed" TRAINING_UID=$(medperf training ls | 
grep trainexp | tail -n 1 | tr -s ' ' | cut -d ' ' -f 2) @@ -403,6 +417,51 @@ fi echo "\n" +########################################################## +echo "=====================================" +echo "Activate fladmin profile" +echo "=====================================" +print_eval medperf profile activate fladmin +checkFailed "fladmin profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Get fladmin certificate" +echo "=====================================" +print_eval medperf certificate get_client_certificate -t $TRAINING_UID +checkFailed "Get fladmin cert failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Check experiment status" +echo "=====================================" +print_eval medperf training get_experiment_status -t $TRAINING_UID +checkFailed "Get experiment status failed" + +sleep 3 # sleep some time then get status again + +print_eval medperf training get_experiment_status -t $TRAINING_UID +checkFailed "Get experiment status failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Update plan parameter" +echo "=====================================" +print_eval medperf training update_plan -t $TRAINING_UID -f "straggler_handling_policy.settings.straggler_cutoff_time" -v 1200 +checkFailed "Update plan failed" +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Waiting for other prcocesses to exit successfully" diff --git a/cli/tests_setup.sh b/cli/tests_setup.sh index 
ba0a9b076..5b1cdce89 100644 --- a/cli/tests_setup.sh +++ b/cli/tests_setup.sh @@ -121,6 +121,7 @@ METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" # FL cubes TRAIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/fl/mlcube/mlcube.yaml" TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" +FLADMIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/bc431ffe6c3b761b28674816e6f26511e8b27042/examples/fl/fl_admin/mlcube/mlcube.yaml" # test users credentials MODELOWNER="testmo@example.com" @@ -129,6 +130,7 @@ BENCHMARKOWNER="testbo@example.com" ADMIN="testadmin@example.com" DATAOWNER2="testdo2@example.com" AGGOWNER="testao@example.com" +FLADMIN="testfladmin@example.com" # local MLCubes for local compatibility tests PREP_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/data_preparator/mlcube" From e2a4ddc1c9ce66c0fe997dc0a96d9e60b4308f26 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 09:34:13 +0000 Subject: [PATCH 117/242] update mock tokens --- mock_tokens/generate_tokens.py | 2 +- mock_tokens/tokens.json | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mock_tokens/generate_tokens.py b/mock_tokens/generate_tokens.py index c4b6420b3..c68e74e2c 100644 --- a/mock_tokens/generate_tokens.py +++ b/mock_tokens/generate_tokens.py @@ -23,7 +23,7 @@ def token_payload(user): } -users = ["testadmin", "testbo", "testmo", "testdo", "testdo2", "testao"] +users = ["testadmin", "testbo", "testmo", "testdo", "testdo2", "testao", "testfladmin"] tokens = {} # Use headers when verifying tokens using json web keys diff --git a/mock_tokens/tokens.json b/mock_tokens/tokens.json index f4063d194..022681ff2 100644 --- a/mock_tokens/tokens.json +++ b/mock_tokens/tokens.json @@ -1,8 +1,9 @@ { - "testadmin@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.ay8fatKc8rjT-Mu08QMz00D8BXmRc74M-02KZqdL8dR71CX6rD2DROSQ9wvf2sgHANcFoNWkYyr8S-Su4DqOPV87L2Jczs2tIPLVSEW28mYrR8YPysNHsSUh3eKi-7wX8F_gxpOhRdjo3Mqa_t3tw5ANfFrRVRl6SF8Mq9mOzirO9dcT3ya4WEGumBrszpJXBWJJxiNr9et1QVBKASUJVY2eclDUiK5vnokIS1nHrPL0sVos1Glcj9gtHyITmm2op7snoMuS65sjAD4dRl08XtB_amOoSZfzYmL8zqQXDYcxgX7zlJsuEsQ28Wm7XjG8tULwLK1XStexFSgL_Kp8HNUuDmWaJ2u-rCpW01Xg86VjQRuVm_eOTRDu8P6h9r0x2f5JykYrdqYw6pDUcryc8MYpceldFx1XZVv0-Fm_5LtKYCh7P9hlN-ND7soR-qeNZ8HWCOVIOYb4257SZ1rhO6Z5qDFJEvrr5aNYzNXLD82mIqTJPN6nLbIVo_EnkHIhPK9QIkfW-tGxgvxfn6zPZNBmWa4EdIRgXx-NTRbCM5SzNmqF89_jZjdIQBdB5OFSdpToFJQ8VvdCaMSPZodTHbv4JZ2sHt0Q9byYMDjVKL6lE_4nyWkVc4X72CG_2LnbkUozqPgoG0_pyeaBvK0oZKAo6YWkifCDZbzUtRh6RKk", - "testbo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.oxbLFiBx1SUHL3QY1v7OcwZVt_WPa_4yNhNCQP4wb78d8QhIrTspkC4DJQjWSQIFtnozbcCwmuNbc-EE5E_VaABiOOUl4ayT6TvFf73oeTvNTLB7JDD8fGFGIAmvo5vwTgtNEbIpE6aBXVAsc93kcRhLEIAtpyj773sruSiLGew-nn2SvXbhW6W1_dh4u0uWBFfhYmbHix9TjxfqDR8437PA5pxr4kPObGvU7DoMm3LvUibLOOc5wu5KeVSaqSNXoZVO-a56ffEA8qdh89HjoDC7MbChDhS3kRzMTh_kVrTv1u29yi1VX7ayAfyGZt17s-R8NPEI5pkxs_gAtu8rc77wZOf8WBXVDlgqA-WZfLT_9jwgWgzl139x73Zdv5c3ptWzFdL2PdvBP4mnzAjA-53mKUqXFqeKbJUvC-P6wGC2k50A3__LOKvphHgx1p2inPaEaD1mqKStHFeb_v6PkDBNp6_654IX68Grnwm01pR67gzWnv3Y7mCZnfHXw-WA662rkPKySf1-ZOYkxw73WRlTQjYn7JL7MjpiiUKTe_zfAO7HvSE2qck2FeXn_iy4CqS49JV5Ur-bUTIr7j0rtftNpFQjmFJFBYpPtrO9r8CO58beYluKliwJkpw6YwphXEC7qvIWx89Xd3PS_A1IVBAaY2cm1cc5PRZrI1hvyh4", - "testmo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.KEs-Zwhq9sGt92fRVIZ6k2wnZ1ij6vZ6zkUP9_ct0zHjm4LszodN5frGxF1NqI10Z-9GCfIButY1cDur78uY005gaEv63iGzETFYOsXaG7y63wzK03Dapwri8E9uPNdTXkiNCyUgd3PzRrcQsKgIBUJVCR2yL_KQacctU5Y9eXhGujQu6PzxFlBti5ajK2t-5sjOdpQOsOOQfRxVfavo1LXAsoEgZqsPnIYtWIC65wfj9gvyPqPjyMxK4jzEF2iszyCq8dq2ubs-7DsTOWTq4PpT9nvmu1h5Sl__q4edIFj8fpzNrY_r2-iKipRmfL6hxiZLGSY419Gn4iXnrD_kBjFQ7iOC7H-v_M_r2ORBm7WxObaKc_Y65Toy2swh_aKorcJzqgCqufkFU9JqTzolhMZStwuRWzT5sXE7434eNzOm7ogb_ogBB2zLyvgLpGIYUEkmQE62tYoT2USEZKI_eTbyxDOxY-WOc-aE2CjBX8K3v2gH2Se1TxVJDa67mAVH7D0XltJIVsYK14Vt989C3K9ZSThWRJciwhXGlHsOVWTL2Wkr9OQxIYEKrREBDF6cxMdz6JrEvoan8DdqYzRDSRhEmFuFUYXZNuXGa2Yz8uvjGzN4rgupOKG2-PMZUH9DgJvqU_6rY_6VGJzyqVGu9HmAlMKvRTr27kw-_UAFewg", - "testdo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.ZBvEnVHNJtVZBBoRsuMJng7TY1MMJ67yYMpwprJBvTvjAZmltnC5_WlXdWnBcoWvnxj1rHcB1LD6pY59TWzYxtchGGmkcWqo4nVjUVSi075vfu3xgPkasXaR3PK9xdTH0aHEcyU4FsNOpnjo1YBqnYS5jMg9Rvz6MGuww1m0coED1sS7YFcKaDj-qtWYxOz85PUyxLLu9REvTdgjvFfr4HdC9bNx5YEcDgpbt7_ZKDiPhzSOFN8ed6R-wyR3gQCW9sEIFjFeO9t-gbPhsZymfELSAHIQ0JWP0hfQavCAN9XQqayrT3wQMCgiM9hTXylnLA_A2_IrpfeInASuDY-CsS0YbkF3iAVes_mxFmokC_dOWPJs2b7P-bXo6gbPpLGkUhzCYREBkfMU1wIqur5IC5PGpKXF4B3I-BSt9knD4a2DXFSkxyxOJaJh4cxbnp3GmdLWP05f9Cuu6MDv79byX7Xq03XvcRTxbTYiiFgMWxGQV3YmdClZEMmPl7t870iZn6XSsVKjvcxihJEp8i4FCyI6v3LUfeg1uIZEIBK04TntdmyW7u_3uE3wdVbTeSvoaW2GGCD5kNyhfVtJtIdyYgtIUQ7JOjHMvKdmj_keVPuBjrfD1QM4MdeYJGtJ_QcZ9LcpxWrbg3FIsA79WC3_qw4pUUdaL0ydSmB6V0mrg5E", - "testdo2@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvMkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGRvMiIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcwOTY5NDEzMCwiZXhwIjoxMTcwOTY5NDEzMH0.SzzdojAzlU7wTucVgY4XJlffVYfdlQhuPJH0ySZH38p6wwnxYNpgHlj3RkE_5p1IIwmIRqpyBxhDO_W6PDOhE4JofwKadoEPoQ5N684wwWKnQr0NI-71gmyGj2sng1BDX0Lpi70yv6iP__OqLAR8tlcIEv7flCy3qpppIxhqBoybE3XBRMCBwrgyO3aurdAW2lZZOihorB9zUjaXlULyvLRxftQ30xosL8JeYfAWWFleHuxJfOK5X_F2vMcsU89jsDf93YMWtyDQBGhFHVXHTA8VLazc1ve5DCkXpCZU0qSo3Fg-8bhrOhZbManxdhLxU-qwvND_uAjch4OC_uPHLUBHMhkWuaa6Ift0EEvgvBS_-0LxplkMiP7pk2YCqpnjB01_1SMHgz2ubAAaTdmd3oj9JcZzRSds9-kFhTcdHA6B4Cx0ZxZPFOhdt2IPCk0D5MRN35-1ZaLwwCEi1NK2XcG7P06_HGZKUV_f6B9enkCevyIr_XDnkGvPet-9kh7y-61ee9qZ1xtruaofw9D2iNaP8V07eCmydjn4zIpk3QehdRofCDF62f-yYYxy6h1GiNH6No0ROsCLBOwLd0--TAxOGtpupLYE_Pet5CleUuQvcPDYuvMR5Hq91UPw8p-_ejW6FbvIPmQio7mEJz_YMeEdHkCJZxNLk2oLQBz8WwM", - "testao@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MDk2OTQxMzAsImV4cCI6MTE3MDk2OTQxMzB9.FfH2lvnqV7FGfUTzQNS-Y9zPA2SbO6onry_fOQtwY08nJJLvIYifMJfOe1tqUVjlud2bG_pFhbI9BYuhvRdt8Y1dpReIjcFkB3gLB1DYkzSwcXVfjdphVOZlWv0ZQTPAoh8Epu3zYtoJNbQukOGPfclMrzndNOWs4k1noZ-xtAu3j-iS3VJDJnIweYZvF_eHJum-xl3-js0mxbLssr1FSx2JZQUuYs6U-SO_gyVSmCpaNb7klMBgfYPZPO2GzN9Lxtv4INEvtg4J4nC5f_SPz8xs9efrWlmdrTsxD0h916Pv3u6hrBawcGzzS9javDlap8HOKgxMtdx9-auwsYZ1-UlcvBBqLJjGIbgAL2ncREpUHIOQIt2dWJyRQz5Xkl6uMjeW-BfAfIRM-oXGd-CObY7TuOlsUcA9VYQ9jvxp2f1bNGt0-Ib0PtnnzNhdEL6zuw1oUaEi-ST5xG_yHKvAa_xfZOncGINNPtvzh77RLVY4eWCcVnwV2OLGNYd0fxLyrm1TfGpUUY4Br_3_x9npeurY8twrkbqUuUvsXbB4TkKgSF8OnyCW-Khrg6t09UURrBYiHUa1jC2RJqMaBv-sUXzVlIU3EG30wheUNnzmctiAmh02XKEpmEE83go6XQc5h8n0hvyLbXSmaftUUBblkFwSFLzPRrg1qIti_lM_B9U" + "testadmin@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFkbWluQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YWRtaW4iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.A1xirQHtKErew3u9J1hVjMxiTeXbJ96ymQqnRE4xOsX8aBzaVPF7T7u6Vf0icVJT-EZ755ladQ4oaotd4iRMUMWyyOwLVSKUvgFWFLGW4GAalNV3YNVL0KjgHg_i10GXk9M9ruEgSuZlVKtY-R7oLcPpzzLwhyO5MQ0VFDrQ4ClXzw_Q6Cln4TP-oJeuIhLqRiTaqba4Yu8Vf3a6KkysJG2Ldo322669H8FiQv9OgWeOQOvvZQV44TJ9OMVQnSYbNgM7NVf5wSkSpho5rvOG2q6MQ7vWbUx7TnPyPUySE0f9M0ql4ycXbPTUIpyf4X44ynqSehtaodW6g-0cwjWJZMX7iHD4mpPaOS5_Z0nf7ARsMv903n_Ybi4GNiJqUfXazACD086Hxh7LMjxPhdLAaL76DkNH2WNp9Kb21-0VTWETcb6-bx2na0nwcVPBYNyhfJGs6gZv9pXQSm-8v345-6FBHW1xA-X00qJnVqOMo_MPDza3lju3HNN1JbMBiRY02DNofMvxUN2AlrcdmZHoaDoOrMhM4IZadnNijenH7UKIn_KCQub7Ji5HYgynpbbr63Rvjvjp2RN_qIEVyoFj71qF56J8Ccio2FIdLjigWJChfBVcEcvt9wr1LgFi3fV2KMOUVTWoRg6kTAtep_iwigKxrQWnkeUuqCy6Dju_ftU", + "testbo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGJvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0Ym8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.BhAlgy16Pz3RYvJcTKTtpwyjoPWDTGK4ExmEJIud8kz7cTFHLf-XBcHRwaAY0RaGJv7MlKnpuBF_dYXAw5wdVQPN7MKKPJRyH4LpKHrvaV3kGGmoKhfUoCrvBSwNBi8a-Y9ywnZhYzG0G4aFx2pLqmJBbjdCmpeeqXIvFlHP7xniV9HnFdT1J5iP42W1JHYZQa4cofDOyo211YtmkzjzfBa8Y7_cSMzA_1_EF-tEZCBomoS_D_ghrNsTOOzSONs1OKAHuPmhqNsS1vuQEO5vAYq9GkAac8gb7aKJ91tGWJYMFwkiDtrNErPAUkPMuuBKhsM6mHkhAc85cHgop8FcH69XkBE-a4GvF9cGv4BQh5mf3-XaRvXf49ZLozRoN6WMlVTDYcO1S8lbM9qYou4k2BaK3vztKZLGbpDhbIr-LlYjFxBYXV8aGzW9aY1cfQ-hbASOS11woYz0I9vi33ODkvAdll6K9xntNJE-9V_hHA36tdGKnpM6JqHWtUeXXyGBSnXmDaW51tEUpOXkLESZ553n-fmZoyX8auyDB_s0dlhMS4IDGgplbiLtnB2sSmrr0hK3zS-WJ8Ht5n6XVRsMvaSXsToArfAOKD6RK9KsJ_BC0ibdv1KHKotzVO_Eq4hUWlnebDCxmw9M_iFOZXCtv0zZ8_T655zrAFRqvVNfczI", + "testmo@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdG1vQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0bW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.pq8dKUIH9Yw03m8DUM17YLn6Kkj3O53B3zZEUlgrMAON0jG59erfb0-vzH6r-M3QMtE1N9MYJO6idUpJYVNcPwK0CY1MFinaA4QqrzHjdbDT0sXu9LuME9hXtzoiyXmi7q3yPESUM_4Y67YBoCbe44V4wUHSP73mQFdIPud0NqhM1vON9zuBa4pTu5GT3JS5DlN6WIZwj-xQtwrYZDCCweWvQbfkct1KxAq0VdyHTM5E4EcgkFtuTk6MGAW7rBEqg72jzk_H3gmP4z5owUl_NpCOM_fWQQSH5iLGn_QKuEbtt8Pi8MB5KBCv7kId2gitIiozzyqgVGAGYMPWAroFmQCc4Pd5pBZRp7fyhE8HFvkurED4RfB0pOhKFc-l9NjV4ESYMeTjBTgdkcCTu9GFKCAEcKONjtVIMYe7Hl0UhOEbNqnDwpOR1qYxYGZTX346Z8-QmG-id9sI5FhJHxOi0p0gv6oHI1UKEZTHmNPWzGLvGbZSJWdJIdESeuhicRwcqeBOBG5tvbqIW71q1tcvAyNco9--bQYY3lLfM57uOWd60e4gLdT-GSaIFEI6jfgS9AWgP9wfujIWyqpKUvjzkZVW-mn6OFtJ53UwRYdhLFs3xwJ2MZdqOtTtidmGks4UxjUDxhcs84TwqgyuBu-gtrAmlinKKuEtKfCeQSrOtrM", + "testdo@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0ZG8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.h_MiBOCO1HpRa-t5899x4MeBaRHT56d3o-aLf7hiVjMGLaM_Qvr4HwYOTiCMdyanK-XQjZF2qpGamFXwBxcDUCqelMaEfGZkt29aTf6NFeXlsq62GF8Nukm8QU1eDTeKAuE1q8bmT7Kfz57njo-rVarZPIr4Lisy8QJmn8-zRC9pW9OQH6yrIXMgxcivPt6JiwavReevDCx4ZLyoy1ULaRLuXf_-MybcSd04Mj3zANgg_3P1LwlY1-m3kmAdx4YI20D0EuLwzu0S8F5in3rpPuhC8J1UcHG_-IMG7Y5J4g0rojZ3UwJKOBfYYzSRhP1cZqUZ7ZDUsq2rqB_03VtEflFVKtn_-xcLTy0zIbLUPQKfTIijJwtKPfh95o7vrVF4zj7tEOmJe-v0K-DN_I4xIJT0ajEoTs81tjEHHcfKNSq6mIAlROSM9lrDNLWKW_rn_0douB6kkKVUry_gFLTeGtTvsJuZWQyjw4B2WVwlhgjh3ECyRPvvhJwrZxLytBx5CzMEpC4bSl8MWvRJOh2oSOGP2xn4FtgjfAJVi0CVV_xZhTGzfhiQ8-PBMAGzgfoEkoQWPsPq-YIS_lHQw3ZTFmiehd5EtmxApHW5kvCv3bqtqCuYlWcWHAdpYmDnpETiYwCJZUTnoCsSsRGbE4TFGPhvH5J6-XFiEBbx4GKcJqU", + "testdo2@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGRvMkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vbG9jYWxob3N0OjgwMDAvIiwic3ViIjoidGVzdGRvMiIsImF1ZCI6Imh0dHBzOi8vbG9jYWxob3N0LWxvY2FsZGV2LyIsImlhdCI6MTcyNTMxNzYxNSwiZXhwIjoxMTcyNTMxNzYxNX0.q_RIdhezmevU-HxjNlNdexL43UISEHz5lyEXckpxyoK9oot8tNCHjNzWVxczJKQDLSTDYseszhouVqNvNPlOM_NRZq-bhXuAFWdpdL1ORqXCYn5RCjC4bLMdgtU_0kB0DVzWhUGYE5HD1aKaJPXkzJEyUsGafXUI5RCih5FeSzIhQjbkiVnI0oKrKKA9bbHJwa98cpw8OBsRHfnnVQPA_ZAaai12iq115HyhHUFPPgSPOKhiqu2bEMSPfYiOj31UB3UsDL2ZdU2Pxzb9UMSmuGmZUO1PkjHfozx5OTDHvcNQkQyDa9PbwprmF4SWCq67ma_OklepyUqi7dRpQXYJRl3cN1JzAlSh-eTmkhCr5SpIsg8_fMyZJ6hqwKDIpfGvUftovmwrumO5AKXoPHi7sdQnOEI31vloUx0ni0wgje7-3SBu3DcCndWZ-nyzEhex7vpEWXEkz7T8MDYxqbBh7ksLukypW8t0NDxZVXywpRaTEqn4G344UEM-L0StEaWA3G8Ed72pexWyAwsX3nF-ZvTnir6v1VvRk6H_v6dX028KgdUVlw87N3NLQ-CwWxcPvIbld2mv9djhfZ4-0cmywLs8bAy7ponsy3LsVY-gkmJUBMVV6865QohMZKKm4Ws_M__v3MQDgQzrQPCxEame7g7hwiblQar6pJYYaS_Doi4", + "testao@example.com": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGFvQGV4YW1wbGUuY29tIiwiaXNzIjoiaHR0cHM6Ly9sb2NhbGhvc3Q6ODAwMC8iLCJzdWIiOiJ0ZXN0YW8iLCJhdWQiOiJodHRwczovL2xvY2FsaG9zdC1sb2NhbGRldi8iLCJpYXQiOjE3MjUzMTc2MTUsImV4cCI6MTE3MjUzMTc2MTV9.FcKfgh6hm_sOFRE_CT8YNyoH11WHx7ANLLUA4gHi731w-Di00cU3BrOmJvOKZXqJRx1cC9HSciUpn2sxEzdL0g47RrurKRXB6Ex6jS9D267eW8_myc64WgF40yFDwK9RvRVtjFtEAuGxBz9r8sqAEDvrfkXTIya3AWm9CwBhD9I2IueSJF0uoYJ58qp0s1C8caQpbuxDh9QLQmTSqyr1lLM6nf1B3SFRm18XWCkv3GU0wH4DG94crMgUBQF-QyUN7jIsAKbI_7OaqwgFVzHY4SJTx6B_8EF8PHHE_11IfoXTI_jcl_TV0SS_e00A6RT6E7YbBNrTbhgpO02qQbFS6QpaG4TCo8hzX85Rv6DqqRxGT0z3mQzAYmZMPKf5YRIgHoLb_4rM5zYOCm3sbE6NIvKEACmHkaXjDT4TbBLNWcNAKAvL-CZpKkvmH6hdszqYCGzZbw8kSEQTMLuuOe1F0IwLVNhUWl5psAhDHSHCJ_7rI6afPZEl9Q0Aex0q9_HAX4Pb7Dc2i4jIxgIM3Ojsm2ODlFaqV9hxeaLtZrRBWE65VZhvUMk1CWgLByX-uVxPZbZAWNdTntjjoRGmXc9XT6wd9DIMdiORta_TqK0fNskgLLmM7v2H8rgcaMfbzVasG8UgKDESsk7BPhKYBavA75UbEv0zMg8Qq79eRyIGoBY", + "testfladmin@example.com": 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwczovL21lZHBlcmYub3JnL2VtYWlsIjoidGVzdGZsYWRtaW5AZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2xvY2FsaG9zdDo4MDAwLyIsInN1YiI6InRlc3RmbGFkbWluIiwiYXVkIjoiaHR0cHM6Ly9sb2NhbGhvc3QtbG9jYWxkZXYvIiwiaWF0IjoxNzI1MzE3NjE1LCJleHAiOjExNzI1MzE3NjE1fQ.rFI0WUYJxtmOBIzxhlK5AUYArPl9yOcE2BmNwzAMApav5-NiGF1_L5WbatZbkbqTKDBVSvI8TrCEG191Cw8SCw-mKigRd4_C7K4HG70DDVZzStLbQUI5irChy4_a4HmA_SipUnR84jeeGNkRJCTHkeQ9WxOylKttX9ZTxbOHsCm2urMQyllEaDEe6V8M1J3JOuFtmVZRL05LCy9jJRPvTrz35o7j1mbdbjPFWe3R3SV5oXBnqtMkFjqOH93PaUgtAHGZ5TOD3sdeBxRyRNHMP7xf7LFZgih-6ai12O0Iq0wn3B5Q54-YEP5ExdmzjeCFtblQ9VzgRxG7isHxWiRytJr-vf9ScpWm9VLOhI71pCOFDg0pDLWt9L525hShv_wXJ0LwjWzU7z6gTUy2HYLGGWgh7XZTn_EqhLb7rx4DD3hLJ0KyeJ0w6UIK2Wjwr85HNQAt00HuaL2zTjyO5rF9GHdWW8SXYMPLFM3egwPJJ72dCEIWH8Hs5JjRftRREC9nGWQPWoebzb73RDtivvY3C8vjk34WjWuaaoKzeyY6PXzSRNMaUk3BVa6lxgHpri9ytQpm1LmTT-ksnndpCl5VPC0LoynJs0qdpSL7JaO73MWsgu1gt81W53leUfn-8EyhJT3x2i74HtGHyyoIJ8nzqRHrIlnwCgD5hlqlSpNFbw0" } \ No newline at end of file From dee5152ce358e9588dc9694a1bf3ca1e2a3b6679 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 09:35:37 +0000 Subject: [PATCH 118/242] update fl testing scripts --- examples/fl/fl/build.sh | 2 +- examples/fl/fl/mlcube/workspace/training_config.yaml | 11 ++++------- examples/fl/fl/setup_test_no_docker.sh | 2 +- examples/fl/fl_admin/build.sh | 2 +- examples/fl/fl_admin/mlcube/workspace/plan.yaml | 7 +++---- examples/fl_post/fl/build.sh | 2 +- .../fl_post/fl/mlcube/workspace/training_config.yaml | 12 +++++------- examples/fl_post/fl/setup_test_no_docker.sh | 2 +- 8 files changed, 17 insertions(+), 23 deletions(-) diff --git a/examples/fl/fl/build.sh b/examples/fl/fl/build.sh index 96b4c9216..08cdbb20c 100644 --- a/examples/fl/fl/build.sh +++ b/examples/fl/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 8c75ddb252930dd6306885a55d0bb9bd0462c333 + git checkout 
7c9d4e7039f51014a4f7b3bedf5e2c7f1d353e68 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl/fl/mlcube/workspace/training_config.yaml b/examples/fl/fl/mlcube/workspace/training_config.yaml index 9400964d0..0b7c17aa5 100644 --- a/examples/fl/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl/fl/mlcube/workspace/training_config.yaml @@ -6,13 +6,10 @@ aggregator: last_state_path: save/classification_last.pbuf rounds_to_train: 2 write_logs: true - admins: - - admin@example.com - allowed_admin_endpoints: - - GetExperimentStatus - - AddCollaborator - - RemoveCollaborator - - SetStragglerCuttoffTime + admins_endpoints_mapping: + testfladmin@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime template: openfl.component.Aggregator assigner: settings: diff --git a/examples/fl/fl/setup_test_no_docker.sh b/examples/fl/fl/setup_test_no_docker.sh index e7625e60d..606847a54 100644 --- a/examples/fl/fl/setup_test_no_docker.sh +++ b/examples/fl/fl/setup_test_no_docker.sh @@ -124,7 +124,7 @@ rm init_weights_miccai.tar.gz cd ../../.. # for admin -ADMIN_CN="admin@example.com" +ADMIN_CN="testfladmin@example.com" mkdir ./for_admin mkdir ./for_admin/node_cert diff --git a/examples/fl/fl_admin/build.sh b/examples/fl/fl_admin/build.sh index 0100f8319..08cdbb20c 100644 --- a/examples/fl/fl_admin/build.sh +++ b/examples/fl/fl_admin/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 84819a5d28abff9c196df61cb931342464c0868d + git checkout 7c9d4e7039f51014a4f7b3bedf5e2c7f1d353e68 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. 
rm -rf openfl diff --git a/examples/fl/fl_admin/mlcube/workspace/plan.yaml b/examples/fl/fl_admin/mlcube/workspace/plan.yaml index 2905f2f18..202a615ff 100644 --- a/examples/fl/fl_admin/mlcube/workspace/plan.yaml +++ b/examples/fl/fl_admin/mlcube/workspace/plan.yaml @@ -7,10 +7,9 @@ aggregator: rounds_to_train: 2 write_logs: true admins_endpoints_mapping: - admin@example.com: - - GetExperimentStatus - - AddCollaborator - - SetStragglerCutoffTime + testfladmin@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime template: openfl.component.Aggregator assigner: diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 0100f8319..08cdbb20c 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 84819a5d28abff9c196df61cb931342464c0868d + git checkout 7c9d4e7039f51014a4f7b3bedf5e2c7f1d353e68 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. 
rm -rf openfl diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 43c1de6a7..d8478a921 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,13 +5,11 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : 10 - admins: - - col1@example.com - allowed_admin_endpoints: - - GetExperimentStatus - - AddCollaborator - - RemoveCollaborator + rounds_to_train : 2 + admins_endpoints_mapping: + testfladmin@example.com: + - GetExperimentStatus + - SetStragglerCuttoffTime collaborator : defaults : plan/defaults/collaborator.yaml diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index 745dffc47..91ea21ec7 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -98,7 +98,7 @@ echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml # for admin -ADMIN_CN="admin@example.com" +ADMIN_CN="testfladmin@example.com" mkdir ./for_admin mkdir ./for_admin/node_cert From 26b4337f431921cfe559c6ca10ca17998bbb3850 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Tue, 3 Sep 2024 11:12:19 +0000 Subject: [PATCH 119/242] update fl integration test mlcube example --- examples/fl/fl/project/aggregator.py | 36 +++---------------- examples/fl/fl/project/collaborator.py | 43 +++++++++++----------- examples/fl/fl/project/hooks.py | 5 ++- examples/fl/fl/project/mlcube.py | 47 ++++++++++++------------ examples/fl/fl/project/utils.py | 50 ++++++++++++++++++++++++++ examples/fl/fl/test.sh | 17 ++++++--- 6 files changed, 115 insertions(+), 83 deletions(-) diff --git a/examples/fl/fl/project/aggregator.py b/examples/fl/fl/project/aggregator.py index 
c0bbeafa1..8a1e7f283 100644 --- a/examples/fl/fl/project/aggregator.py +++ b/examples/fl/fl/project/aggregator.py @@ -1,39 +1,10 @@ -from utils import ( - get_aggregator_fqdn, - prepare_node_cert, - prepare_ca_cert, - prepare_plan, - prepare_cols_list, - prepare_init_weights, - create_workspace, - get_weights_path, -) - import os import shutil from subprocess import check_call from distutils.dir_util import copy_tree -def start_aggregator( - input_weights, - node_cert_folder, - ca_cert_folder, - output_logs, - output_weights, - plan_path, - collaborators, - report_path, -): - - workspace_folder = os.path.join(output_logs, "workspace") - create_workspace(workspace_folder) - prepare_plan(plan_path, workspace_folder) - prepare_cols_list(collaborators, workspace_folder) - prepare_init_weights(input_weights, workspace_folder) - fqdn = get_aggregator_fqdn(workspace_folder) - prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) - prepare_ca_cert(ca_cert_folder, workspace_folder) +def start_aggregator(workspace_folder, output_logs, output_weights, report_path): check_call(["fx", "aggregator", "start"], cwd=workspace_folder) @@ -41,7 +12,8 @@ def start_aggregator( # perhaps investigate overriding plan entries? # NOTE: logs and weights are copied, even if target folders are not empty - copy_tree(os.path.join(workspace_folder, "logs"), output_logs) + if os.path.exists(os.path.join(workspace_folder, "logs")): + copy_tree(os.path.join(workspace_folder, "logs"), output_logs) # NOTE: conversion fails since openfl needs sample data... 
# weights_paths = get_weights_path(fl_workspace) @@ -56,5 +28,5 @@ def start_aggregator( # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) - with open(report_path, 'w') as f: + with open(report_path, "w") as f: f.write("IsDone: 1") diff --git a/examples/fl/fl/project/collaborator.py b/examples/fl/fl/project/collaborator.py index 38c5048b6..fb4cdd1c2 100644 --- a/examples/fl/fl/project/collaborator.py +++ b/examples/fl/fl/project/collaborator.py @@ -1,32 +1,29 @@ -from utils import ( - get_collaborator_cn, - prepare_node_cert, - prepare_ca_cert, - prepare_plan, - create_workspace, -) import os +from utils import get_collaborator_cn import shutil from subprocess import check_call -def start_collaborator( - data_path, - labels_path, - node_cert_folder, - ca_cert_folder, - plan_path, - output_logs, -): - workspace_folder = os.path.join(output_logs, "workspace") - create_workspace(workspace_folder) - prepare_plan(plan_path, workspace_folder) +def start_collaborator(workspace_folder): cn = get_collaborator_cn() - prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) - prepare_ca_cert(ca_cert_folder, workspace_folder) - - # set log files - check_call(["fx", "collaborator", "start", "-n", cn], cwd=workspace_folder) + check_call( + [os.environ.get("OPENFL_EXECUTABLE", "fx"), "collaborator", "start", "-n", cn], + cwd=workspace_folder, + ) # Cleanup shutil.rmtree(workspace_folder, ignore_errors=True) + + +def check_connectivity(workspace_folder): + cn = get_collaborator_cn() + check_call( + [ + os.environ.get("OPENFL_EXECUTABLE", "fx"), + "collaborator", + "connectivity_check", + "-n", + cn, + ], + cwd=workspace_folder, + ) diff --git a/examples/fl/fl/project/hooks.py b/examples/fl/fl/project/hooks.py index dd3960ba4..9dc59582f 100644 --- a/examples/fl/fl/project/hooks.py +++ b/examples/fl/fl/project/hooks.py @@ -30,9 +30,9 @@ def collaborator_pre_training_hook( ca_cert_folder, plan_path, output_logs, + workspace_folder, ): cn = 
get_collaborator_cn() - workspace_folder = os.path.join(output_logs, "workspace") target_data_folder = os.path.join(workspace_folder, "data", cn) os.makedirs(target_data_folder, exist_ok=True) @@ -69,6 +69,7 @@ def collaborator_post_training_hook( ca_cert_folder, plan_path, output_logs, + workspace_folder, ): pass @@ -82,6 +83,7 @@ def aggregator_pre_training_hook( plan_path, collaborators, report_path, + workspace_folder, ): pass @@ -95,5 +97,6 @@ def aggregator_post_training_hook( plan_path, collaborators, report_path, + workspace_folder, ): pass diff --git a/examples/fl/fl/project/mlcube.py b/examples/fl/fl/project/mlcube.py index 9e4a7e728..064440e95 100644 --- a/examples/fl/fl/project/mlcube.py +++ b/examples/fl/fl/project/mlcube.py @@ -1,9 +1,7 @@ """MLCube handler file""" -import os -import shutil import typer -from collaborator import start_collaborator +from collaborator import start_collaborator, check_connectivity from aggregator import start_aggregator from plan import generate_plan from hooks import ( @@ -12,22 +10,11 @@ collaborator_pre_training_hook, collaborator_post_training_hook, ) +from utils import generic_setup, generic_teardown, setup_collaborator, setup_aggregator app = typer.Typer() -def _setup(output_logs): - tmp_folder = os.path.join(output_logs, ".tmp") - os.makedirs(tmp_folder, exist_ok=True) - # TODO: this should be set before any code imports tempfile - os.environ["TMPDIR"] = tmp_folder - - -def _teardown(output_logs): - tmp_folder = os.path.join(output_logs, ".tmp") - shutil.rmtree(tmp_folder, ignore_errors=True) - - @app.command("train") def train( data_path: str = typer.Option(..., "--data_path"), @@ -37,23 +24,27 @@ def train( plan_path: str = typer.Option(..., "--plan_path"), output_logs: str = typer.Option(..., "--output_logs"), ): - _setup(output_logs) - collaborator_pre_training_hook( + workspace_folder = generic_setup(output_logs) + setup_collaborator( data_path=data_path, labels_path=labels_path, 
node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, + workspace_folder=workspace_folder, ) - start_collaborator( + check_connectivity(workspace_folder) + collaborator_pre_training_hook( data_path=data_path, labels_path=labels_path, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, + workspace_folder=workspace_folder, ) + start_collaborator(workspace_folder=workspace_folder) collaborator_post_training_hook( data_path=data_path, labels_path=labels_path, @@ -61,8 +52,9 @@ def train( ca_cert_folder=ca_cert_folder, plan_path=plan_path, output_logs=output_logs, + workspace_folder=workspace_folder, ) - _teardown(output_logs) + generic_teardown(output_logs) @app.command("start_aggregator") @@ -76,8 +68,8 @@ def start_aggregator_( collaborators: str = typer.Option(..., "--collaborators"), report_path: str = typer.Option(..., "--report_path"), ): - _setup(output_logs) - aggregator_pre_training_hook( + workspace_folder = generic_setup(output_logs) + setup_aggregator( input_weights=input_weights, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, @@ -86,8 +78,9 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + workspace_folder=workspace_folder, ) - start_aggregator( + aggregator_pre_training_hook( input_weights=input_weights, node_cert_folder=node_cert_folder, ca_cert_folder=ca_cert_folder, @@ -96,6 +89,13 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + workspace_folder=workspace_folder, + ) + start_aggregator( + workspace_folder=workspace_folder, + output_logs=output_logs, + output_weights=output_weights, + report_path=report_path, ) aggregator_post_training_hook( input_weights=input_weights, @@ -106,8 +106,9 @@ def start_aggregator_( plan_path=plan_path, collaborators=collaborators, report_path=report_path, + 
workspace_folder=workspace_folder, ) - _teardown(output_logs) + generic_teardown(output_logs) @app.command("generate_plan") diff --git a/examples/fl/fl/project/utils.py b/examples/fl/fl/project/utils.py index d92add606..a0da69a16 100644 --- a/examples/fl/fl/project/utils.py +++ b/examples/fl/fl/project/utils.py @@ -3,6 +3,56 @@ import shutil +def generic_setup(output_logs): + tmpfolder = os.path.join(output_logs, ".tmp") + os.makedirs(tmpfolder, exist_ok=True) + # NOTE: this should be set before any code imports tempfile + os.environ["TMPDIR"] = tmpfolder + workspace_folder = os.path.join(output_logs, "workspace") + os.makedirs(workspace_folder, exist_ok=True) + create_workspace(workspace_folder) + return workspace_folder + + +def setup_collaborator( + data_path, + labels_path, + node_cert_folder, + ca_cert_folder, + plan_path, + output_logs, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + cn = get_collaborator_cn() + prepare_node_cert(node_cert_folder, "client", f"col_{cn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def setup_aggregator( + input_weights, + node_cert_folder, + ca_cert_folder, + output_logs, + output_weights, + plan_path, + collaborators, + report_path, + workspace_folder, +): + prepare_plan(plan_path, workspace_folder) + prepare_cols_list(collaborators, workspace_folder) + prepare_init_weights(input_weights, workspace_folder) + fqdn = get_aggregator_fqdn(workspace_folder) + prepare_node_cert(node_cert_folder, "server", f"agg_{fqdn}", workspace_folder) + prepare_ca_cert(ca_cert_folder, workspace_folder) + + +def generic_teardown(output_logs): + tmp_folder = os.path.join(output_logs, ".tmp") + shutil.rmtree(tmp_folder, ignore_errors=True) + + def create_workspace(fl_workspace): plan_folder = os.path.join(fl_workspace, "plan") workspace_config = os.path.join(fl_workspace, ".workspace") diff --git a/examples/fl/fl/test.sh b/examples/fl/fl/test.sh index 95bd5b673..ae856d794 100644 --- 
a/examples/fl/fl/test.sh +++ b/examples/fl/fl/test.sh @@ -13,10 +13,19 @@ COL1="medperf mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIP COL2="medperf mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" COL3="medperf mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" -gnome-terminal -- bash -c "$AGG; bash" -gnome-terminal -- bash -c "$COL1; bash" -gnome-terminal -- bash -c "$COL2; bash" -gnome-terminal -- bash -c "$COL3; bash" +# gnome-terminal -- bash -c "$AGG; bash" +# gnome-terminal -- bash -c "$COL1; bash" +# gnome-terminal -- bash -c "$COL2; bash" +# gnome-terminal -- bash -c "$COL3; bash" +rm agg.log col1.log col2.log col3.log +$AGG >>agg.log & +sleep 6 +$COL1 >>col1.log & +sleep 6 +$COL2 >>col2.log & +sleep 6 +$COL3 >>col3.log & +wait # docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 # python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 From 1868af95838169e1525128762e61ed8ca7cf2466 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 4 Sep 2024 10:38:19 -0700 Subject: [PATCH 120/242] recent changes --- 
examples/fl_post/fl/project/Dockerfile | 6 ++++-- examples/fl_post/fl/project/hooks.py | 2 ++ examples/fl_post/fl/project/mlcube.py | 2 ++ .../fl_post/fl/project/nnunet_data_setup.py | 17 +++++++++-------- .../fl_post/fl/project/nnunet_model_setup.py | 5 ++++- examples/fl_post/fl/project/src/nnunet_v1.py | 10 +++++++--- examples/fl_post/fl/test.sh | 5 ++++- 7 files changed, 32 insertions(+), 15 deletions(-) diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index d12baa7bb..c5e1ef2ee 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -2,7 +2,9 @@ FROM local/openfl:local ENV LANG C.UTF-8 ENV CUDA_VISIBLE_DEVICES="0" - +# ENV http_proxy="http://proxy-us.intel.com:912" +# ENV https_proxy="http://proxy-us.intel.com:912" +ENV no_proxy=localhost,spr-gpu01.jf.intel.com # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ @@ -22,4 +24,4 @@ RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt # Copy mlcube project folder COPY . 
/mlcube_project -ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] \ No newline at end of file +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py index a079bd586..ff1cda9ee 100644 --- a/examples/fl_post/fl/project/hooks.py +++ b/examples/fl_post/fl/project/hooks.py @@ -54,6 +54,8 @@ def collaborator_pre_training_hook( # when evan runs, init_model_path, init_model_info_path should be None # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) + print(f"Brandon DEBUG - postopp_pardir will be pointed to: {workspace_folder} which has data subfolder containing: {os.listdir(os.path.join(workspace_folder, 'data'))}") + nnunet_setup.main(postopp_pardir=workspace_folder, three_digit_task_num=537, # FIXME: does this need to be set in any particular way? init_model_path=f'{init_nnunet_directory}/model_initial_checkpoint.model', diff --git a/examples/fl_post/fl/project/mlcube.py b/examples/fl_post/fl/project/mlcube.py index ba0140395..d88c4886a 100644 --- a/examples/fl_post/fl/project/mlcube.py +++ b/examples/fl_post/fl/project/mlcube.py @@ -49,6 +49,7 @@ def train( output_logs=output_logs, init_nnunet_directory=init_nnunet_directory, ) + start_collaborator( data_path=data_path, labels_path=labels_path, @@ -65,6 +66,7 @@ def train( plan_path=plan_path, output_logs=output_logs, ) + _teardown(output_logs) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 3f24c9515..dcce22e8d 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -246,11 +246,8 @@ def setup_fl_data(postopp_pardir, should be run using a virtual environment that has nnunet version 1 installed. args: - postopp_pardirs(list of str) : Parent directories for postopp data. 
The length of the list should either be - equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, - the samples within that single directory will be used to create num_insititutions shards. - If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. - Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + postopp_pardir(str) : Parent directory for postopp data. + This directory should have 'data' and 'labels' subdirectories, with structure: ├── data │ ├── AAAC_0 │ │ ├── 2008.03.30 @@ -298,7 +295,7 @@ def setup_fl_data(postopp_pardir, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. + three_digit_task_num(str): Should start with '5'. task_name(str) : Any string task name. 
percent_train(float) : what percent of data is put into the training data split (rest to val) split_logic(str) : Determines how train/val split is performed @@ -337,6 +334,7 @@ def setup_fl_data(postopp_pardir, subject_to_timestamps = {} print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") + print(f"\nBrandon DEBUG -- Here are all subjects: {all_subjects}\n\n") for postopp_subject_dir in all_subjects: subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, postopp_data_dirpath=postopp_data_dirpath, @@ -357,13 +355,16 @@ def setup_fl_data(postopp_pardir, print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") if plans_path: subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp", "-pl2d", "None"]) plans_identifier_for_model_writing = shared_plans_identifier else: # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) plans_identifier_for_model_writing = local_plans_identifier + # Brandon debug + # print(f"\nListing directory of plans path: {os.listdir(os.path.join(os.environ['nnUNet_preprocessed'], 'Task537_FLPost'))}\n\n") + # Now compute our own stratified splits file, keeping all timestampts for 
a given subject exclusively in either train or val write_splits_file(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index cd9412790..5f31bfec6 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -67,9 +67,12 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one """ - # Remove 2D data and 2D data info if appropriate + # Removing 2D data is not longer needed since we set "-pl2d None during plan and preprocessing call" + # TODO: remove this comment once tested + """ if network != '2d': delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) + """ # get or create architecture info diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index c1d1a7458..e5b6a073e 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -228,6 +228,13 @@ def __init__(self, **kwargs): trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch + print(f"\n\nBrandon DEBUG - dataset directory is: {dataset_directory} \n") + print(f"\n\nBrandon DEBUG - dataset directory contains: {os.listdir(dataset_directory)} \n") + print(f"\n\nBrandon DEBUG - plans file variable has value: {plans_file} \n") + + # TODO: call validation separately + trainer.initialize(not validation_only) + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times @@ -238,9 +245,6 @@ def __init__(self, **kwargs): trainer.num_batches_per_epoch = 
num_train_batches_per_epoch trainer.num_val_batches_per_epoch = num_val_batches_per_epoch - # TODO: call validation separately - trainer.initialize(not validation_only) - if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": trainer.save_checkpoint( join(trainer.output_folder, "model_final_checkpoint.model") diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 9463c56a1..d70d35c3e 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,3 +1,6 @@ +export HTTPS_PROXY= +export http_proxy= + # generate plan and copy it to each node GENERATE_PLAN_PLATFORM="docker" AGG_PLATFORM="docker" @@ -16,7 +19,7 @@ cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf --platform $COL2_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & From 89d2ca26aeb9d99280654835fdf3a9c9f7996aef Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 9 Sep 2024 13:19:04 -0700 Subject: [PATCH 121/242] initial changes to support timeouts rather than partial_epoch values --- .../fl_post/fl/project/nnunet_data_setup.py | 9 ++--- examples/fl_post/fl/project/src/nnunet_v1.py | 34 +++++++++++++++---- .../fl_post/fl/project/src/runner_nnunetv1.py | 20 ++++++++--- 3 files changed, 
46 insertions(+), 17 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index dcce22e8d..db3894e9d 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -333,8 +333,6 @@ def setup_fl_data(postopp_pardir, # Track the subjects and timestamps for each shard subject_to_timestamps = {} - print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") - print(f"\nBrandon DEBUG -- Here are all subjects: {all_subjects}\n\n") for postopp_subject_dir in all_subjects: subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, postopp_data_dirpath=postopp_data_dirpath, @@ -354,17 +352,14 @@ def setup_fl_data(postopp_pardir, # Now call the os process to preprocess the data print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") if plans_path: - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp", "-pl2d", "None"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) plans_identifier_for_model_writing = shared_plans_identifier else: # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) 
plans_identifier_for_model_writing = local_plans_identifier - # Brandon debug - # print(f"\nListing directory of plans path: {os.listdir(os.path.join(os.environ['nnUNet_preprocessed'], 'Task537_FLPost'))}\n\n") - # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val write_splits_file(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index e5b6a073e..25abc059c 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -56,7 +56,9 @@ def seed_everything(seed=1234): def train_nnunet(epochs, current_epoch, - partial_epoch=1.0, + train_val_cutoff=None, + train_cutoff_part=None, + val_cutoff_part=None, network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', @@ -81,7 +83,9 @@ def train_nnunet(epochs, """ epochs (int): Number of epochs to train for on top of current epoch current_epoch (int): Which epoch will be used to grab the model - partial_epoch (float): + train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
+ train_cutoff_part (float): Portion of train_val_cutoff going to training + val_cutoff_part (float): Portion of train_val_cutoff going to val task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" validation_only: use this if you want to only run the validation @@ -132,6 +136,14 @@ def __init__(self, **kwargs): if args.deterministic: seed_everything() + # validation of some args + + if args.train_val_cutoff or args.train_cutoff_part or args.val_cutoff_part: + if not (args.train_val_cutoff and args.train_cutoff_part and args.val_cutoff_part): + raise ValueError(f"If any of train_val_cutoff, train_cutoff_part, or val_cutoff_part are None, then they all must be None.") + if args.train_cutoff_part + args.val_cutoff_part >= 1.0: + raise ValueError(f"train_cutoff_part + val_cutoff_part must be less than 1.0 to account for some time left outside of those two loops.") + task = args.task fold = args.fold network = args.network @@ -238,8 +250,16 @@ def __init__(self, **kwargs): # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times - num_train_batches_per_epoch = int(partial_epoch * len(trainer.dataset_tr)/trainer.batch_size) - num_val_batches_per_epoch = int(partial_epoch * len(trainer.dataset_val)/trainer.batch_size) + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + + train_cutoff = int(np.floor(train_cutoff_part * train_val_cutoff)) + val_cutoff = int(np.floor(val_cutoff_part * train_val_cutoff)) + + if train_cutoff == 0: + raise ValueError(f"The setting for train_cutoff_part does not allow (with use of np.floor) for a non-zero train loop time in seconds.") + if val_cutoff == 0: + raise ValueError(f"The setting for val_cutoff_part does not allow (with use of np.floor) for a non-zero val loop time in seconds.") # the 
nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch @@ -266,15 +286,17 @@ def __init__(self, **kwargs): # new training without pretraine weights, do nothing pass - trainer.run_training() + batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff) else: # if valbest: # trainer.load_best_checkpoint(train=False) # else: # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() + + return batches_applied_train, batches_applied_val - trainer.network.eval() + # trainer.network.eval() # if fold == "all": # print("--> fold == 'all'") diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 0116e4105..7bf705bd4 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -34,14 +34,20 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - partial_epoch=1.0, + train_val_cutoff=None, + train_cutoff_part=None, + val_cutoff_part=None, + other_cutoff_part=None, nnunet_task=None, config_path=None, **kwargs): """Initialize. Args: - partial_epoch (float) : What portion of the data to use to compute number of batches per epoch (for both train and val). + train_val_cutoff (int) : Total time (in seconds) limit to use in approximating a restriction to training and validation activities. + train_cutoff_part (float) : Portion of train_val_cutoff going to training + val_cutoff_part (float) : Portion of train_val_cutoff going to val + other_cutoff_part (float) : Portion of train_val_cutoff going to the rest of the 'train' function nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. 
kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). @@ -74,7 +80,10 @@ def __init__(self, **kwargs, ) - self.partial_epoch = partial_epoch + self.train_val_cutoff = train_val_cutoff + self.train_cutoff_part = train_cutoff_part + self.val_cutoff_part = val_cutoff_part + self.other_cutoff_part = other_cutoff_part self.config_path = config_path @@ -155,7 +164,10 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Should we put this in a separate process? train_nnunet(epochs=epochs, current_epoch=current_epoch, - partial_epoch=self.partial_epoch, + train_val_cutoff=self.train_val_cutoff, + train_cutoff_part = self.train_cutoff_part, + val_cutoff_part = self.val_cutoff_part, + other_cutoff_part = self.other_cutoff_part, task=self.data_loader.get_task_name()) # 3. Load metrics from checkpoint From a49f892cb29bd4ad3a1f2ab6e9eee9ff47b312ad Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 13 Sep 2024 17:05:37 -0700 Subject: [PATCH 122/242] now tracking time in train and val loop directly --- examples/fl_post/fl/project/src/nnunet_v1.py | 35 ++++-------------- .../fl_post/fl/project/src/runner_nnunetv1.py | 37 +++++++++---------- 2 files changed, 26 insertions(+), 46 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 25abc059c..64c43e7d6 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -253,8 +253,8 @@ def __init__(self, **kwargs): num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - train_cutoff = int(np.floor(train_cutoff_part * train_val_cutoff)) - val_cutoff = int(np.floor(val_cutoff_part * train_val_cutoff)) + train_cutoff = train_cutoff_part * train_val_cutoff + val_cutoff = val_cutoff_part * train_val_cutoff 
if train_cutoff == 0: raise ValueError(f"The setting for train_cutoff_part does not allow (with use of np.floor) for a non-zero train loop time in seconds.") @@ -293,29 +293,10 @@ def __init__(self, **kwargs): # else: # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() + + train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) + + return train_completed, val_completed + - return batches_applied_train, batches_applied_val - - # trainer.network.eval() - - # if fold == "all": - # print("--> fold == 'all'") - # print("--> DONE") - # else: - # # predict validation - # trainer.validate( - # save_softmax=args.npz, - # validation_folder_name=val_folder, - # run_postprocessing_on_folds=not disable_postprocessing_on_folds, - # overwrite=args.val_disable_overwrite, - # ) - - # if network == "3d_lowres" and not args.disable_next_stage_pred: - # print("predicting segmentations for the next stage of the cascade") - # predict_next_stage( - # trainer, - # join( - # dataset_directory, - # trainer.plans["data_identifier"] + "_stage%d" % 1, - # ), - # ) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7bf705bd4..3a38493ce 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -34,20 +34,16 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_val_cutoff=None, - train_cutoff_part=None, - val_cutoff_part=None, - other_cutoff_part=None, + train_cutoff=16, + val_cutoff=2, nnunet_task=None, config_path=None, **kwargs): """Initialize. Args: - train_val_cutoff (int) : Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
- train_cutoff_part (float) : Portion of train_val_cutoff going to training - val_cutoff_part (float) : Portion of train_val_cutoff going to val - other_cutoff_part (float) : Portion of train_val_cutoff going to the rest of the 'train' function + train_cutoff (int) : Total time (in seconds) allowed for iterating over train batches (plus or minus one iteration since check willl be once an iteration). + val_cutoff (int) : Total time (in seconds) allowed for iterating over val batches (plus or minus one iteration since check willl be once an iteration). nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). @@ -80,10 +76,8 @@ def __init__(self, **kwargs, ) - self.train_val_cutoff = train_val_cutoff - self.train_cutoff_part = train_cutoff_part - self.val_cutoff_part = val_cutoff_part - self.other_cutoff_part = other_cutoff_part + self.train_cutoff = train_cutoff + self.val_cutoff = val_cutoff self.config_path = config_path @@ -162,13 +156,15 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- train_nnunet(epochs=epochs, - current_epoch=current_epoch, - train_val_cutoff=self.train_val_cutoff, - train_cutoff_part = self.train_cutoff_part, - val_cutoff_part = self.val_cutoff_part, - other_cutoff_part = self.other_cutoff_part, - task=self.data_loader.get_task_name()) + train_completed, val_completed = train_nnunet(epochs=epochs, + current_epoch=current_epoch, + train_cutoff=self.train_cutoff, + val_cutoff = self.val_cutoff, + task=self.data_loader.get_task_name()) + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed)*100}% of the val work.") + + # 3. Load metrics from checkpoint (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] @@ -176,6 +172,9 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): metrics = {'train_loss': all_tr_losses[-1], 'val_eval': all_val_eval_metrics[-1]} + ###################################################################################################### + # TODO: Provide train_completed and val_completed to be incorporated into the collab weight computation + ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) From b38d9e4aba4b86e1b99abe804aa6743ae4ab2d38 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 13 Sep 2024 17:17:32 -0700 Subject: [PATCH 123/242] moving new param change to the nnunet train function --- examples/fl_post/fl/project/src/nnunet_v1.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 64c43e7d6..261c5431e 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -56,9 +56,8 @@ def seed_everything(seed=1234): def train_nnunet(epochs, 
current_epoch, - train_val_cutoff=None, - train_cutoff_part=None, - val_cutoff_part=None, + traincutoff=np.inf, + val_cutoff=np.inf, network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', @@ -240,10 +239,6 @@ def __init__(self, **kwargs): trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch - print(f"\n\nBrandon DEBUG - dataset directory is: {dataset_directory} \n") - print(f"\n\nBrandon DEBUG - dataset directory contains: {os.listdir(dataset_directory)} \n") - print(f"\n\nBrandon DEBUG - plans file variable has value: {plans_file} \n") - # TODO: call validation separately trainer.initialize(not validation_only) @@ -253,14 +248,6 @@ def __init__(self, **kwargs): num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - train_cutoff = train_cutoff_part * train_val_cutoff - val_cutoff = val_cutoff_part * train_val_cutoff - - if train_cutoff == 0: - raise ValueError(f"The setting for train_cutoff_part does not allow (with use of np.floor) for a non-zero train loop time in seconds.") - if val_cutoff == 0: - raise ValueError(f"The setting for val_cutoff_part does not allow (with use of np.floor) for a non-zero val loop time in seconds.") - # the nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch trainer.num_val_batches_per_epoch = num_val_batches_per_epoch From 5c34957e93edf222e0a7744e60401a16a4dfd460 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 13 Sep 2024 17:24:00 -0700 Subject: [PATCH 124/242] had a typo in param and also unsued args check --- examples/fl_post/fl/project/src/nnunet_v1.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 261c5431e..1448485d8 100644 --- 
a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -56,7 +56,7 @@ def seed_everything(seed=1234): def train_nnunet(epochs, current_epoch, - traincutoff=np.inf, + train_cutoff=np.inf, val_cutoff=np.inf, network='3d_fullres', network_trainer='nnUNetTrainerV2', @@ -135,14 +135,6 @@ def __init__(self, **kwargs): if args.deterministic: seed_everything() - # validation of some args - - if args.train_val_cutoff or args.train_cutoff_part or args.val_cutoff_part: - if not (args.train_val_cutoff and args.train_cutoff_part and args.val_cutoff_part): - raise ValueError(f"If any of train_val_cutoff, train_cutoff_part, or val_cutoff_part are None, then they all must be None.") - if args.train_cutoff_part + args.val_cutoff_part >= 1.0: - raise ValueError(f"train_cutoff_part + val_cutoff_part must be less than 1.0 to account for some time left outside of those two loops.") - task = args.task fold = args.fold network = args.network From 0ccf964b18d2f6a4e27b4b75ae38d16d0e413205 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 16 Sep 2024 12:04:23 -0700 Subject: [PATCH 125/242] some clean up --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 3a38493ce..e074c1424 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -6,7 +6,6 @@ """ # TODO: Clean up imports -# TODO: ask Micah if this has to be changed (most probably no) import os import subprocess @@ -119,7 +118,6 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict - # Brandon DEBUGGING seen_keys = [] for k in checkpoint_dict['state_dict']: if k not in seen_keys: From 
07ef45150240736754cdf075da80ff7883dcb658 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 19 Sep 2024 11:05:51 -0700 Subject: [PATCH 126/242] corrected percent completed calculation --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index e074c1424..9979930c6 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -34,7 +34,7 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): def __init__(self, train_cutoff=16, - val_cutoff=2, + val_cutoff=1, nnunet_task=None, config_path=None, **kwargs): @@ -160,7 +160,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name()) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed)*100}% of the val work.") + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") From 0b412ec94c9dac0d9ece5b14c014f81e01b37767 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 25 Sep 2024 13:03:39 -0700 Subject: [PATCH 127/242] supporting max_num_epochs as nnunet runner init parameter --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 +++- examples/fl_post/fl/project/src/runner_nnunetv1.py | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 1448485d8..255dfb71b 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,7 +54,8 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(epochs, +def 
train_nnunet(max_num_epochs, + epochs, current_epoch, train_cutoff=np.inf, val_cutoff=np.inf, @@ -205,6 +206,7 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, + max_num_epochs=max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 9979930c6..96e238df6 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,10 +33,11 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=16, - val_cutoff=1, + train_cutoff=160, + val_cutoff=10, nnunet_task=None, config_path=None, + max_num_epochs=2, **kwargs): """Initialize. @@ -78,6 +79,7 @@ def __init__(self, self.train_cutoff = train_cutoff self.val_cutoff = val_cutoff self.config_path = config_path + self.max_num_epochs=max_num_epochs def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -154,7 +156,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- train_completed, val_completed = train_nnunet(epochs=epochs, + train_completed, val_completed = train_nnunet(max_num_epochs=self.max_num_epochs, + epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, val_cutoff = self.val_cutoff, From 81c552b400a059f62a584c15d27283833f1feb32 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 25 Sep 2024 16:28:41 -0700 Subject: [PATCH 128/242] changing previous commit using max_num_epochs to TOTAL_max_num_epochs --- examples/fl_post/fl/project/src/nnunet_v1.py | 7 ++++--- examples/fl_post/fl/project/src/runner_nnunetv1.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 255dfb71b..7e669254a 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,7 +54,7 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(max_num_epochs, +def train_nnunet(TOTAL_max_num_epochs, epochs, current_epoch, train_cutoff=np.inf, @@ -81,7 +81,8 @@ def train_nnunet(max_num_epochs, pretrained_weights=None): """ - epochs (int): Number of epochs to train for on top of current epoch + TOTAL_max_num_epochs (int): Provides the total number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during the course of federated training) + epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
train_cutoff_part (float): Portion of train_val_cutoff going to training @@ -206,7 +207,7 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, - max_num_epochs=max_num_epochs, + TOTAL_max_num_epochs=TOTAL_max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 96e238df6..068dded4f 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -37,7 +37,7 @@ def __init__(self, val_cutoff=10, nnunet_task=None, config_path=None, - max_num_epochs=2, + TOTAL_max_num_epochs=2, **kwargs): """Initialize. @@ -46,6 +46,7 @@ def __init__(self, val_cutoff (int) : Total time (in seconds) allowed for iterating over val batches (plus or minus one iteration since check willl be once an iteration). nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. + TOTAL_max_num_epochs (int) : Total number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). TODO: """ @@ -79,7 +80,7 @@ def __init__(self, self.train_cutoff = train_cutoff self.val_cutoff = val_cutoff self.config_path = config_path - self.max_num_epochs=max_num_epochs + self.TOTAL_max_num_epochs=TOTAL_max_num_epochs def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -156,7 +157,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- train_completed, val_completed = train_nnunet(max_num_epochs=self.max_num_epochs, + train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, From eea4279c61f19e111a74a198f05813b279e7c578 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 26 Sep 2024 10:05:44 -0700 Subject: [PATCH 129/242] putting timeout check at beginning of train and val loops so that we can perform train or val independently using the timouts --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 068dded4f..ee4df59c7 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,11 +33,11 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=160, - val_cutoff=10, + train_cutoff=np.inf, + val_cutoff=np.inf, nnunet_task=None, config_path=None, - TOTAL_max_num_epochs=2, + TOTAL_max_num_epochs=1000, **kwargs): """Initialize. 
From 4e303d12c8199cff3b47be742b4d88ccd67a8c0d Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 26 Sep 2024 10:37:02 -0700 Subject: [PATCH 130/242] first pass at accounting for separate train and val tasks into config --- .../fl/mlcube/workspace/training_config.yaml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 43c1de6a7..793383cbf 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,7 +5,7 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : 10 + rounds_to_train : 300 admins: - col1@example.com allowed_admin_endpoints: @@ -47,17 +47,28 @@ assigner : - name : train_and_validate percentage : 1.0 tasks : - # - aggregated_model_validation + - aggregated_model_validation - train - # - locally_tuned_model_validation + - locally_tuned_model_validation tasks : defaults : plan/defaults/tasks_torch.yaml + aggregated_model_validation: + function : validate + kwargs : + metrics : + - val_eval + epochs : 1 train: function : train kwargs : metrics : - train_loss + epochs : 1 + locally_tuned_model_validation: + function : validate + kwargs : + metrics : - val_eval epochs : 1 From b7970cfe4bd8f33f73244c7df98dca0a0b8759c0 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 12:08:45 -0700 Subject: [PATCH 131/242] allowing val_cutoff to pass through to nnunet training function --- .../fl_post/fl/project/src/runner_nnunetv1.py | 109 +++++++----------- 1 file changed, 44 insertions(+), 65 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index ee4df59c7..11c513079 100644 --- 
a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -149,10 +149,9 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. Train function existing externally + # 2. Train/val function existing externally # Some todo inside function below - # TODO: test for off-by-one error - # TODO: we need to disable validation if possible, and separately call validation + # TODO: test for off-by-one error # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases @@ -166,79 +165,59 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") + # double check + if val_completed != 0.0: + raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") + - - # 3. Load metrics from checkpoint - (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each epoch, so we select the most recent epoch - metrics = {'train_loss': all_tr_losses[-1], - 'val_eval': all_val_eval_metrics[-1]} + # 3. 
Load metrics from checkpoint + (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkpoint each call to train_nnunet, so it is critical that we are grabbing this right after the call above + metrics = {'train_loss': all_tr_losses[-1]} ###################################################################################################### - # TODO: Provide train_completed and val_completed to be incorporated into the collab weight computation + # TODO: Provide train_completed to be incorporated into the collab weight computation ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - - + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): - """ - Run the trained model on validation data; report results. - - Parameters - ---------- - input_tensor_dict : either the last aggregated or locally trained model - - Returns - ------- - output_tensor_dict : {TensorKey: nparray} (these correspond to acc, - precision, f1_score, etc.) - """ - - raise NotImplementedError() - - """ - TBD - for now commenting out - - self.rebuild_model(round_num, input_tensor_dict, validation=True) - - # 1. Save model in native format - self.save_native(self.mlcube_model_in_path) - - # 2. Call MLCube validate task - platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) - task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') - proc = subprocess.run(["mlcube_docker", - "run", - "--mlcube={}".format(self.mlcube_dir), - "--platform={}".format(platform_yaml), - "--task={}".format(task_yaml)]) + # TODO: Figure out the right name to use for this method and the default assigner + """Perform validation.""" - # 3. 
Load any metrics - metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + # 2. Train/val function existing externally + # Some todo inside function below + # TODO: test for off-by-one error + + # FIXME: we need to understand how to use round_num instead of current_epoch + # this will matter in straggler handling cases + # TODO: Should we put this in a separate process? + train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + epochs=epochs, + current_epoch=current_epoch, + train_cutoff=0, + val_cutoff = self.val_cutoff, + task=self.data_loader.get_task_name()) + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") - # set the validation data size - sample_count = int(metrics.pop(self.evaluation_sample_count_key)) - self.data_loader.set_valid_data_size(sample_count) + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # 3. Load metrics from checkpoint + (_, all_val_losses, _, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkopint each call to train_nnunet, so it is critical that we are grabbing this right after the call above + metrics = {'val_eval': all_val_eval_metrics[-1]} - # 4. 
Convert to tensorkeys - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - output_tensor_dict = { - TensorKey( - metric_name, origin, round_num, True, tags - ): np.array(metrics[metric_name]) - for metric_name in metrics - } - - return output_tensor_dict, {} - """ + ###################################################################################################### + # TODO: Provide val_completed to be incorporated into the collab weight computation + ###################################################################################################### + return self.convert_results_to_tensorkeys(col_name, round_num, metrics) def load_metrics(self, filepath): From ad014dbe2d21eccd276b487aa03897dfca81a679 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 12:24:22 -0700 Subject: [PATCH 132/242] Revert "allowing val_cutoff to pass through to nnunet training function" This reverts commit b7970cfe4bd8f33f73244c7df98dca0a0b8759c0. --- .../fl_post/fl/project/src/runner_nnunetv1.py | 109 +++++++++++------- 1 file changed, 65 insertions(+), 44 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 11c513079..ee4df59c7 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -149,9 +149,10 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. Train/val function existing externally + # 2. 
Train function existing externally # Some todo inside function below - # TODO: test for off-by-one error + # TODO: test for off-by-one error + # TODO: we need to disable validation if possible, and separately call validation # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases @@ -165,59 +166,79 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") - # double check - if val_completed != 0.0: - raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") - - # 3. Load metrics from checkpoint - (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkpoint each call to train_nnunet, so it is critical that we are grabbing this right after the call above - metrics = {'train_loss': all_tr_losses[-1]} + + # 3. 
Load metrics from checkpoint + (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkopint each epoch, so we select the most recent epoch + metrics = {'train_loss': all_tr_losses[-1], + 'val_eval': all_val_eval_metrics[-1]} ###################################################################################################### - # TODO: Provide train_completed to be incorporated into the collab weight computation + # TODO: Provide train_completed and val_completed to be incorporated into the collab weight computation ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - + + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): - # TODO: Figure out the right name to use for this method and the default assigner - """Perform validation.""" + """ + Run the trained model on validation data; report results. - self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) - # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. Train/val function existing externally - # Some todo inside function below - # TODO: test for off-by-one error - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? 
- train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, - epochs=epochs, - current_epoch=current_epoch, - train_cutoff=0, - val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name()) - - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") + Parameters + ---------- + input_tensor_dict : either the last aggregated or locally trained model - # double check - if train_completed != 0.0: - raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") - - # 3. Load metrics from checkpoint - (_, all_val_losses, _, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each call to train_nnunet, so it is critical that we are grabbing this right after the call above - metrics = {'val_eval': all_val_eval_metrics[-1]} + Returns + ------- + output_tensor_dict : {TensorKey: nparray} (these correspond to acc, + precision, f1_score, etc.) + """ + raise NotImplementedError() - ###################################################################################################### - # TODO: Provide val_completed to be incorporated into the collab weight computation - ###################################################################################################### - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) + """ - TBD - for now commenting out + + self.rebuild_model(round_num, input_tensor_dict, validation=True) + + # 1. Save model in native format + self.save_native(self.mlcube_model_in_path) + + # 2. 
Call MLCube validate task + platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) + task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') + proc = subprocess.run(["mlcube_docker", + "run", + "--mlcube={}".format(self.mlcube_dir), + "--platform={}".format(platform_yaml), + "--task={}".format(task_yaml)]) + + # 3. Load any metrics + metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) + + # set the validation data size + sample_count = int(metrics.pop(self.evaluation_sample_count_key)) + self.data_loader.set_valid_data_size(sample_count) + + # 4. Convert to tensorkeys + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + output_tensor_dict = { + TensorKey( + metric_name, origin, round_num, True, tags + ): np.array(metrics[metric_name]) + for metric_name in metrics + } + + return output_tensor_dict, {} + + """ def load_metrics(self, filepath): From 2530189d7f9e71386166547a6d8d1c527dbd0437 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 12:59:08 -0700 Subject: [PATCH 133/242] implementing separate train and val tasks --- .../fl_post/fl/project/src/runner_nnunetv1.py | 111 +++++++----------- 1 file changed, 45 insertions(+), 66 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index ee4df59c7..3f5b238e9 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -149,10 +149,9 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. 
Train function existing externally + # 2. Train/val function existing externally # Some todo inside function below - # TODO: test for off-by-one error - # TODO: we need to disable validation if possible, and separately call validation + # TODO: test for off-by-one error # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases @@ -160,85 +159,65 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, - train_cutoff=self.train_cutoff, + train_cutoff=0, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name()) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") + # double check + if val_completed != 0.0: + raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") + - - # 3. Load metrics from checkpoint - (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each epoch, so we select the most recent epoch - metrics = {'train_loss': all_tr_losses[-1], - 'val_eval': all_val_eval_metrics[-1]} + # 3. 
Load metrics from checkpoint + (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkpoint each call to train_nnunet, so it is critical that we are grabbing this right after the call above + metrics = {'train_loss': all_tr_losses[-1]} ###################################################################################################### - # TODO: Provide train_completed and val_completed to be incorporated into the collab weight computation + # TODO: Provide train_completed to be incorporated into the collab weight computation ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - - + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): - """ - Run the trained model on validation data; report results. - - Parameters - ---------- - input_tensor_dict : either the last aggregated or locally trained model - - Returns - ------- - output_tensor_dict : {TensorKey: nparray} (these correspond to acc, - precision, f1_score, etc.) - """ - - raise NotImplementedError() - - """ - TBD - for now commenting out - - self.rebuild_model(round_num, input_tensor_dict, validation=True) - - # 1. Save model in native format - self.save_native(self.mlcube_model_in_path) - - # 2. Call MLCube validate task - platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) - task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') - proc = subprocess.run(["mlcube_docker", - "run", - "--mlcube={}".format(self.mlcube_dir), - "--platform={}".format(platform_yaml), - "--task={}".format(task_yaml)]) + # TODO: Figure out the right name to use for this method and the default assigner + """Perform validation.""" - # 3. 
Load any metrics - metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + # 2. Train/val function existing externally + # Some todo inside function below + # TODO: test for off-by-one error + + # FIXME: we need to understand how to use round_num instead of current_epoch + # this will matter in straggler handling cases + # TODO: Should we put this in a separate process? + train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + epochs=epochs, + current_epoch=current_epoch, + train_cutoff=0, + val_cutoff = self.val_cutoff, + task=self.data_loader.get_task_name()) + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") - # set the validation data size - sample_count = int(metrics.pop(self.evaluation_sample_count_key)) - self.data_loader.set_valid_data_size(sample_count) + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # 3. Load metrics from checkpoint + (_, all_val_losses, _, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkopint each call to train_nnunet, so it is critical that we are grabbing this right after the call above + metrics = {'val_eval': all_val_eval_metrics[-1]} - # 4. 
Convert to tensorkeys - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - output_tensor_dict = { - TensorKey( - metric_name, origin, round_num, True, tags - ): np.array(metrics[metric_name]) - for metric_name in metrics - } - - return output_tensor_dict, {} - """ + ###################################################################################################### + # TODO: Provide val_completed to be incorporated into the collab weight computation + ###################################################################################################### + return self.convert_results_to_tensorkeys(col_name, round_num, metrics) def load_metrics(self, filepath): From b096e3dc991ef5585f768bc44a68c21a54460b5a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 13:17:33 -0700 Subject: [PATCH 134/242] correcting timeouts for train and val tasks --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- examples/fl_post/fl/setup_clean.sh | 9 +++++++++ examples/fl_post/fl/test.sh | 4 ++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 3f5b238e9..41d2e3e1f 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -159,8 +159,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, - train_cutoff=0, - val_cutoff = self.val_cutoff, + train_cutoff=self.train_cutoff, + val_cutoff = 0, task=self.data_loader.get_task_name()) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") diff --git a/examples/fl_post/fl/setup_clean.sh 
b/examples/fl_post/fl/setup_clean.sh index 6615c2968..5db13b158 100644 --- a/examples/fl_post/fl/setup_clean.sh +++ b/examples/fl_post/fl/setup_clean.sh @@ -1,6 +1,15 @@ + +HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" + +cd $HOMEDIR + + + rm -rf ./mlcube_agg rm -rf ./mlcube_col1 rm -rf ./mlcube_col2 rm -rf ./mlcube_col3 +rm -rf ./mlcube_col4 +rm -rf ./mlcube_col5 rm -rf ./ca rm -rf ./for_admin diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index d70d35c3e..479be18d8 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,6 +1,10 @@ export HTTPS_PROXY= export http_proxy= +HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" + +cd $HOMEDIR + # generate plan and copy it to each node GENERATE_PLAN_PLATFORM="docker" AGG_PLATFORM="docker" From e04cc15a7be91626cf555db6fd9e8bbbd7f04e2f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 13:25:33 -0700 Subject: [PATCH 135/242] five collaborators --- examples/fl_post/fl/test.sh | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 479be18d8..3fd377a50 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -8,9 +8,11 @@ cd $HOMEDIR # generate plan and copy it to each node GENERATE_PLAN_PLATFORM="docker" AGG_PLATFORM="docker" -COL1_PLATFORM="singularity" +COL1_PLATFORM="docker" COL2_PLATFORM="docker" COL3_PLATFORM="docker" +COL4_PLATFORM="docker" +COL5_PLATFORM="docker" medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace @@ -18,6 +20,8 @@ rm -r ./mlcube_agg/workspace/plan cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace +cp ./mlcube_agg/workspace/plan.yaml 
./mlcube_col4/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col5/workspace cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes @@ -25,6 +29,8 @@ AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task st COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +COL4="medperf --platform $COL4_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" +COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & @@ -35,13 +41,16 @@ COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./ml rm agg.log col1.log col2.log col3.log $AGG >>agg.log & sleep 6 +$COL1 >>col1.log & +sleep 6 $COL2 >>col2.log & sleep 6 -$COL3 >>col3.log & -# sleep 6 -# $COL2 >> col2.log & -# sleep 6 -# $COL3 >> col3.log & +$COL3 >> col3.log & +sleep 6 +$COL4 >> col4.log & +sleep 6 +$COL5 >> col5.log & + wait # docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro 
--volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 From 0829fb7207f456ea8ca4c934d42284791502ef3e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 15:02:11 -0700 Subject: [PATCH 136/242] changing fl plan file to indicate which model to apply for local vs global validation --- .../fl_post/fl/be_setup_test_no_docker.sh | 249 ++++++++++++++++++ examples/fl_post/fl/intel_build.sh | 19 ++ .../fl/mlcube/workspace/training_config.yaml | 2 + examples/fl_post/fl/test.sh | 2 +- 4 files changed, 271 insertions(+), 1 deletion(-) create mode 100644 examples/fl_post/fl/be_setup_test_no_docker.sh create mode 100755 examples/fl_post/fl/intel_build.sh diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh new file mode 100644 index 000000000..b1f99c8ad --- /dev/null +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -0,0 +1,249 @@ +while getopts t flag; do + case "${flag}" in + t) TWO_COL_SAME_CERT="true" ;; + esac +done +TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" + +COL1_CN="col1@example.com" +COL2_CN="col2@example.com" +COL3_CN="col3@example.com" +COL4_CN="col4@example.com" +COL5_CN="col5@example.com" + +COL1_LABEL="col1@example.com" +COL2_LABEL="col2@example.com" +COL3_LABEL="col3@example.com" +COL4_LABEL="col4@example.com" +COL5_LABEL="col5@example.com" + +if ${TWO_COL_SAME_CERT}; then + COL1_CN="org1@example.com" + COL2_CN="org2@example.com" + COL3_CN="org3@example.com" + COL4_CN="org4@example.com" + COL5_CN="org5@example.com" # in this case this var is not used actually. 
+fi + +HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" + +cd $HOMEDIR + +mkdir mlcube_agg +mkdir mlcube_col1 +mkdir mlcube_col2 +mkdir mlcube_col3 +mkdir mlcube_col4 +mkdir mlcube_col5 + + + +cp -r ./mlcube/* ./mlcube_agg +cp -r ./mlcube/* ./mlcube_col1 +cp -r ./mlcube/* ./mlcube_col2 +cp -r ./mlcube/* ./mlcube_col3 +cp -r ./mlcube/* ./mlcube_col4 +cp -r ./mlcube/* ./mlcube_col5 + +mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert +mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert +mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert +mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert +mkdir ./mlcube_col4/workspace/node_cert ./mlcube_col4/workspace/ca_cert +mkdir ./mlcube_col5/workspace/node_cert ./mlcube_col5/workspace/ca_cert +mkdir ca + +HOSTNAME_=$(hostname -A | cut -d " " -f 1) +# HOSTNAME_=$(hostname -I | cut -d " " -f 1) + + +# root ca +openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 +openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ + -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" + +# col1 +sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf +cd mlcube_col1/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd $HOMEDIR + +# col2 +sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf +cd mlcube_col2/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt 
rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd $HOMEDIR + +# col3 +if ${TWO_COL_SAME_CERT}; then + never goes here cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert + cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert +else + sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf + sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf + cd mlcube_col3/workspace/node_cert + openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 + openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client + openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf + rm csr.csr + cp ../../../ca/root.crt ../ca_cert/ + cd $HOMEDIR +fi + +# col4 +sed -i "/^commonName = /c\commonName = $COL4_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL4_CN" csr.conf +cd mlcube_col4/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd $HOMEDIR + +# col5 +sed -i "/^commonName = /c\commonName = $COL5_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $COL5_CN" csr.conf +cd mlcube_col5/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config 
../../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd $HOMEDIR + + + +# agg +sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf +cd mlcube_agg/workspace/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server +openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf +rm csr.csr +cp ../../../ca/root.crt ../ca_cert/ +cd $HOMEDIR + +# aggregator_config +echo "address: $HOSTNAME_" >> mlcube_agg/workspace/aggregator_config.yaml +echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml + +# cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL4_LABEL: $COL4_CN" >>mlcube_agg/workspace/cols.yaml +echo "$COL5_LABEL: $COL5_CN" >>mlcube_agg/workspace/cols.yaml + +# for admin +ADMIN_CN="admin@example.com" + +mkdir ./for_admin +mkdir ./for_admin/node_cert + +sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf +sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf +cd for_admin/node_cert +openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 +openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client +openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key \ + -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf +rm csr.csr +mkdir ../ca_cert +cp -r ../../ca/root.crt 
../ca_cert/root.crt +cd $HOMEDIR + +# THIS IS BRANDON'S CODE COPYING IN THE SAME DATA +mkdir mlcube_col1/workspace/labels +mkdir mlcube_col1/workspace/data + +mkdir mlcube_col2/workspace/labels +mkdir mlcube_col2/workspace/data + +mkdir mlcube_col3/workspace/labels +mkdir mlcube_col3/workspace/data + +mkdir mlcube_col4/workspace/labels +mkdir mlcube_col4/workspace/data + +mkdir mlcube_col5/workspace/labels +mkdir mlcube_col5/workspace/data + +# DATA_DIR="test_data_links_testforhasan" +# DATA_DIR="test_data_links_random_times_0" +# DATA_DIR="test_data_links" + +# this is the one I had success running on +#DATA_DIRS=test_data_small_from_hasan + +SIZE="hundred" +#SIZE="thousand" + +DATA_DIR_1="test_${SIZE}_BraTS20_3square_0" +DATA_DIR_2="test_${SIZE}_BraTS20_3square_1" +DATA_DIR_3="test_${SIZE}_BraTS20_3square_2" +DATA_DIR_4="test_${SIZE}_BraTS20_3square_3" +DATA_DIR_5="test_${SIZE}_BraTS20_3square_4" + +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/labels/* mlcube_col1/workspace/labels +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/data/* mlcube_col1/workspace/data + +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_2/labels/* mlcube_col2/workspace/labels +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_2/data/* mlcube_col2/workspace/data + +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_3/labels/* mlcube_col3/workspace/labels +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_3/data/* mlcube_col3/workspace/data + +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_4/labels/* mlcube_col4/workspace/labels +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_4/data/* mlcube_col4/workspace/data + +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_5/labels/* mlcube_col5/workspace/labels +cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_5/data/* mlcube_col5/workspace/data + +# wget https://storage.googleapis.com/medperf-storage/fltest29July/flpost_add29july.tar.gz I copied on spr01 into /home/edwardsb/repo_extras/hasan_medperperf_extras + +# aggregator additional files +mkdir 
mlcube_agg/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_weights mlcube_agg/workspace/additional_files +# maybe I don't need the one immediately below (only for collaborators) +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_agg/workspace/additional_files + +# col1 additional files +mkdir mlcube_col1/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col1/workspace/additional_files + +# col2 additional files +mkdir mlcube_col2/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col2/workspace/additional_files + +# col3 additional files +mkdir mlcube_col3/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col3/workspace/additional_files + +# col4 additional files +mkdir mlcube_col4/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col4/workspace/additional_files + +# col5 additional files +mkdir mlcube_col5/workspace/additional_files +cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col5/workspace/additional_files + + +# source /home/edwardsb/virtual/hasan_medperf/bin/activate diff --git a/examples/fl_post/fl/intel_build.sh b/examples/fl_post/fl/intel_build.sh new file mode 100755 index 000000000..3a58ee319 --- /dev/null +++ b/examples/fl_post/fl/intel_build.sh @@ -0,0 +1,19 @@ +while getopts b flag; do + case "${flag}" in + b) BUILD_BASE="true" ;; + esac +done +BUILD_BASE="${BUILD_BASE:-false}" + +if ${BUILD_BASE}; then + git clone https://github.com/hasan7n/openfl.git + cd openfl + git checkout 54f27c61c274f64af3d028f962f62392419cb67e + docker build \ + --build-arg http_proxy="http://proxy-us.intel.com:912" \ + --build-arg 
https_proxy="http://proxy-us.intel.com:912" \ + -t local/openfl:local -f openfl-docker/Dockerfile.base . + cd .. + rm -rf openfl +fi +mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always -Pdocker.build_args="--build-arg http_proxy='http://proxy-us.intel.com:912' --build-arg https_proxy='http://proxy-us.intel.com:912'" diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 793383cbf..d6cec6ad9 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -57,6 +57,7 @@ tasks : function : validate kwargs : metrics : + apply : global - val_eval epochs : 1 train: @@ -70,6 +71,7 @@ tasks : kwargs : metrics : - val_eval + apply : local epochs : 1 compression_pipeline : diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 3fd377a50..911ed987f 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -38,7 +38,7 @@ COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./ml # gnome-terminal -- bash -c "$COL1; bash" # gnome-terminal -- bash -c "$COL2; bash" # gnome-terminal -- bash -c "$COL3; bash" -rm agg.log col1.log col2.log col3.log +rm agg.log col1.log col2.log col3.log col4.log col5.log $AGG >>agg.log & sleep 6 $COL1 >>col1.log & From b59c2dc1595f4e257ea4659bc2505796eba580df Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 15:18:09 -0700 Subject: [PATCH 137/242] more of my local files,and changing be_setup.. 
script to copy over changes from hasan_medperf code --- examples/fl_post/fl/be_setup_test_no_docker.sh | 15 +++++++++------ examples/fl_post/fl/project/Dockerfile | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index b1f99c8ad..d31157285 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -25,6 +25,9 @@ if ${TWO_COL_SAME_CERT}; then COL5_CN="org5@example.com" # in this case this var is not used actually. fi + +CODE_CHANGE_DIR="/home/edwardsb/repositories/hasan_medperf/examples/fl_post/fl" + HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" cd $HOMEDIR @@ -38,12 +41,12 @@ mkdir mlcube_col5 -cp -r ./mlcube/* ./mlcube_agg -cp -r ./mlcube/* ./mlcube_col1 -cp -r ./mlcube/* ./mlcube_col2 -cp -r ./mlcube/* ./mlcube_col3 -cp -r ./mlcube/* ./mlcube_col4 -cp -r ./mlcube/* ./mlcube_col5 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_agg +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col1 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col2 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col3 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col4 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col5 mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index c5e1ef2ee..f48a3caa0 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -6,6 +6,8 @@ ENV CUDA_VISIBLE_DEVICES="0" # ENV https_proxy="http://proxy-us.intel.com:912" ENV no_proxy=localhost,spr-gpu01.jf.intel.com +ENV no_proxy__="http://proxy-us.intel.com:912" + # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ From 
c0b9f51cdb29278d7b5342c6dc85acbdf79679c3 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 15:46:10 -0700 Subject: [PATCH 138/242] correcting training config --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index d6cec6ad9..3f82323ee 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -55,11 +55,10 @@ tasks : defaults : plan/defaults/tasks_torch.yaml aggregated_model_validation: function : validate - kwargs : + kwargs : metrics : - apply : global - val_eval - epochs : 1 + apply : global train: function : train kwargs : @@ -72,7 +71,6 @@ tasks : metrics : - val_eval apply : local - epochs : 1 compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml From c8c1b3c4edfa0129d7d9dff44db7a1379d75a33e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 18:24:07 -0700 Subject: [PATCH 139/242] some more test infrastructure --- examples/fl_post/fl/be_setup_test_no_docker.sh | 1 + examples/fl_post/fl/mlcube/workspace/training_config.yaml | 6 +++--- examples/fl_post/fl/project/Dockerfile | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index d31157285..641f8b4c3 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -41,6 +41,7 @@ mkdir mlcube_col5 +cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_agg cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col1 cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col2 diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 
3f82323ee..5538220c8 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -57,19 +57,19 @@ tasks : function : validate kwargs : metrics : - - val_eval + - val_eval apply : global train: function : train kwargs : metrics : - - train_loss + - train_loss epochs : 1 locally_tuned_model_validation: function : validate kwargs : metrics : - - val_eval + - val_eval apply : local compression_pipeline : diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index f48a3caa0..745c4507f 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -6,7 +6,7 @@ ENV CUDA_VISIBLE_DEVICES="0" # ENV https_proxy="http://proxy-us.intel.com:912" ENV no_proxy=localhost,spr-gpu01.jf.intel.com -ENV no_proxy__="http://proxy-us.intel.com:912" +ENV no_proxy_______="http://proxy-us.intel.com:912" # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ From f99a0d79b5bdd42c7dab99f9b35100cbb89afcc7 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 18:38:41 -0700 Subject: [PATCH 140/242] need to pass epochs value of 1 to nnunet train function used for validation task --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 41d2e3e1f..7293f0cec 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -196,7 +196,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, - epochs=epochs, + epochs=1, current_epoch=current_epoch, train_cutoff=0, val_cutoff = self.val_cutoff, From 502bd85c0d904aa30bbe1a6429b4cdbe98e8d08f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 18:48:05 -0700 Subject: [PATCH 141/242] now copying over whole repo into raid homedir --- examples/fl_post/fl/be_setup_test_no_docker.sh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index 641f8b4c3..19713ae78 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -30,6 +30,8 @@ CODE_CHANGE_DIR="/home/edwardsb/repositories/hasan_medperf/examples/fl_post/fl" HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" +cp -r $CODE_CHANGE_DIR/* $HOMEDIR + cd $HOMEDIR mkdir mlcube_agg @@ -41,13 +43,13 @@ mkdir mlcube_col5 -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_agg -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col1 -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col2 -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col3 -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col4 -cp -r $CODE_CHANGE_DIR/mlcube/* ./mlcube_col5 +cp -r ./mlcube/* ./mlcube +cp -r ./mlcube/* ./mlcube_agg +cp -r ./mlcube/* ./mlcube_col1 +cp -r ./mlcube/* ./mlcube_col2 +cp -r ./mlcube/* ./mlcube_col3 +cp -r ./mlcube/* ./mlcube_col4 +cp -r ./mlcube/* ./mlcube_col5 mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert From 25ae1e42ddbf344417f1b1f01f5babc0d9760f28 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 19:24:50 -0700 Subject: [PATCH 142/242] another testing change --- examples/fl_post/fl/be_setup_test_no_docker.sh | 2 +- 
examples/fl_post/fl/project/Dockerfile | 2 +- examples/fl_post/fl/project/requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index 19713ae78..73629f8db 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -30,7 +30,7 @@ CODE_CHANGE_DIR="/home/edwardsb/repositories/hasan_medperf/examples/fl_post/fl" HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" -cp -r $CODE_CHANGE_DIR/* $HOMEDIR +# cp -r $CODE_CHANGE_DIR/* $HOMEDIR cd $HOMEDIR diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index 745c4507f..1ba403e20 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -6,7 +6,7 @@ ENV CUDA_VISIBLE_DEVICES="0" # ENV https_proxy="http://proxy-us.intel.com:912" ENV no_proxy=localhost,spr-gpu01.jf.intel.com -ENV no_proxy_______="http://proxy-us.intel.com:912" +ENV no_proxy________="http://proxy-us.intel.com:912" # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt index 8f03308f9..ef2281ff7 100644 --- a/examples/fl_post/fl/project/requirements.txt +++ b/examples/fl_post/fl/project/requirements.txt @@ -1,4 +1,4 @@ onnx==1.13.0 typer==0.9.0 -git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@main#egg=nnunet +git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@supporting_partial_epochs numpy==1.26.4 From bdaffbc0d93fb1527d8ee8152f232cfa2e341174 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 19:55:56 -0700 Subject: [PATCH 143/242] now using validation_only parameter --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7293f0cec..02b77f738 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -200,7 +200,8 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): current_epoch=current_epoch, train_cutoff=0, val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name()) + task=self.data_loader.get_task_name(), + validation_only=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") From ca8050070086ea9bf663ee0569e21012befe5e96 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 20:03:58 -0700 Subject: [PATCH 144/242] allowing at most 1 second over one batch of val (during training) to avoid NNUnet code throwing exception due to empty val results --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 02b77f738..9af505f51 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -156,20 +156,17 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
+ # TODO: Currently allowing at most 1 second of valiation over one batch in order to avoid NNUnet code throwing exception due + # to empty val results train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, - val_cutoff = 0, + val_cutoff = 1, task=self.data_loader.get_task_name()) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") - # double check - if val_completed != 0.0: - raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") - - # 3. Load metrics from checkpoint (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] # these metrics are appended to the checkpoint each call to train_nnunet, so it is critical that we are grabbing this right after the call above From 4b9d6048c05492f70c496ad0e2bc29766cb9e8a6 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 30 Sep 2024 20:24:26 -0700 Subject: [PATCH 145/242] need to account for loaders being different when validate_only is used --- examples/fl_post/fl/project/src/nnunet_v1.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 7e669254a..60b6362d7 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -240,8 +240,12 @@ def __init__(self, **kwargs): # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times - num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) - num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + if validate_only: + num_train_batches_per_epoch = 0 + 
num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + else: + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + num_val_batches_per_epoch = 0 # the nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch @@ -276,8 +280,12 @@ def __init__(self, **kwargs): # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() - train_completed = batches_applied_train / float(num_train_batches_per_epoch) - val_completed = batches_applied_val / float(num_val_batches_per_epoch) + if validate_only: + train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = 0 + else: + train_completed = 0 + val_completed = batches_applied_val / float(num_val_batches_per_epoch) return train_completed, val_completed From ac45d054c319c29ca312244153bfa00a5f1d8f2a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 2 Oct 2024 09:00:50 -0700 Subject: [PATCH 146/242] typo 'validate_only' --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 60b6362d7..1bc85b6ed 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -240,7 +240,7 @@ def __init__(self, **kwargs): # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times - if validate_only: + if validation_only: num_train_batches_per_epoch = 0 num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) else: @@ -280,7 +280,7 @@ def __init__(self, **kwargs): # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() - if validate_only: + if validation_only: train_completed = 
batches_applied_train / float(num_train_batches_per_epoch) val_completed = 0 else: From a194f6b4a5e3886e8b6fdf7c8063a16ff2200208 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 2 Oct 2024 11:36:39 -0700 Subject: [PATCH 147/242] local setup script changes --- examples/fl_post/fl/be_setup_test_no_docker.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index 73629f8db..19713ae78 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -30,7 +30,7 @@ CODE_CHANGE_DIR="/home/edwardsb/repositories/hasan_medperf/examples/fl_post/fl" HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" -# cp -r $CODE_CHANGE_DIR/* $HOMEDIR +cp -r $CODE_CHANGE_DIR/* $HOMEDIR cd $HOMEDIR From e6fa758cc439459fcb4b9c6596cb39b28b797171 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 2 Oct 2024 12:16:53 -0700 Subject: [PATCH 148/242] moving back to not using validation_only in order that we still have access to training data for calibrating data size --- examples/fl_post/fl/project/src/nnunet_v1.py | 16 ++++------------ .../fl_post/fl/project/src/runner_nnunetv1.py | 6 +++--- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 1bc85b6ed..77456c4de 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -240,12 +240,8 @@ def __init__(self, **kwargs): # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times - if validation_only: - num_train_batches_per_epoch = 0 - num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - else: - num_train_batches_per_epoch = 
int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) - num_val_batches_per_epoch = 0 + num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) # the nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch @@ -280,12 +276,8 @@ def __init__(self, **kwargs): # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() - if validation_only: - train_completed = batches_applied_train / float(num_train_batches_per_epoch) - val_completed = 0 - else: - train_completed = 0 - val_completed = batches_applied_val / float(num_val_batches_per_epoch) + train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) return train_completed, val_completed diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 9af505f51..4dde6de8e 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -184,7 +184,8 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + # offseting current_epoch in order that it is set back to where it was in previous checkpoint after the rain call + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - 1 # 2. 
Train/val function existing externally # Some todo inside function below # TODO: test for off-by-one error @@ -197,8 +198,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): current_epoch=current_epoch, train_cutoff=0, val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name(), - validation_only=True) + task=self.data_loader.get_task_name()) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") From 1f05091d671b1301df52f2e7736268b04602aad8 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 09:08:57 -0700 Subject: [PATCH 149/242] preping to move all local stuff over the files named with be_... --- examples/fl_post/fl/be_setup_clean.sh | 15 +++++ examples/fl_post/fl/be_test.sh | 57 +++++++++++++++++++ examples/fl_post/fl/intel_build.sh | 3 + examples/fl_post/fl/project/be_Dockerfile | 29 ++++++++++ .../fl_post/fl/project/src/runner_nnunetv1.py | 5 +- 5 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 examples/fl_post/fl/be_setup_clean.sh create mode 100755 examples/fl_post/fl/be_test.sh create mode 100644 examples/fl_post/fl/project/be_Dockerfile diff --git a/examples/fl_post/fl/be_setup_clean.sh b/examples/fl_post/fl/be_setup_clean.sh new file mode 100644 index 000000000..5db13b158 --- /dev/null +++ b/examples/fl_post/fl/be_setup_clean.sh @@ -0,0 +1,15 @@ + +HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" + +cd $HOMEDIR + + + +rm -rf ./mlcube_agg +rm -rf ./mlcube_col1 +rm -rf ./mlcube_col2 +rm -rf ./mlcube_col3 +rm -rf ./mlcube_col4 +rm -rf ./mlcube_col5 +rm -rf ./ca +rm -rf ./for_admin diff --git a/examples/fl_post/fl/be_test.sh b/examples/fl_post/fl/be_test.sh new file mode 100755 index 000000000..911ed987f --- /dev/null +++ b/examples/fl_post/fl/be_test.sh @@ -0,0 +1,57 @@ +export HTTPS_PROXY= +export http_proxy= + 
+HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" + +cd $HOMEDIR + +# generate plan and copy it to each node +GENERATE_PLAN_PLATFORM="docker" +AGG_PLATFORM="docker" +COL1_PLATFORM="docker" +COL2_PLATFORM="docker" +COL3_PLATFORM="docker" +COL4_PLATFORM="docker" +COL5_PLATFORM="docker" + +medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan +mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace +rm -r ./mlcube_agg/workspace/plan +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col4/workspace +cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col5/workspace +cp ./mlcube_agg/workspace/plan.yaml ./for_admin + +# Run nodes +AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" +COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +COL4="medperf --platform $COL4_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" +COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" + +# medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & + +# gnome-terminal -- bash -c "$AGG; bash" +# gnome-terminal -- bash -c "$COL1; bash" +# gnome-terminal -- bash -c "$COL2; bash" +# 
gnome-terminal -- bash -c "$COL3; bash" +rm agg.log col1.log col2.log col3.log col4.log col5.log +$AGG >>agg.log & +sleep 6 +$COL1 >>col1.log & +sleep 6 +$COL2 >>col2.log & +sleep 6 +$COL3 >> col3.log & +sleep 6 +$COL4 >> col4.log & +sleep 6 +$COL5 >> col5.log & + +wait + +# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 +# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/intel_build.sh b/examples/fl_post/fl/intel_build.sh index 3a58ee319..54494127a 100755 --- a/examples/fl_post/fl/intel_build.sh +++ b/examples/fl_post/fl/intel_build.sh @@ -5,6 +5,9 @@ while getopts b flag; do done BUILD_BASE="${BUILD_BASE:-false}" +# copy over changes from be_Dockerfile to Dockerfile +cp ./project/be_Dockerfile ./project/Dockerfile + if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl diff --git a/examples/fl_post/fl/project/be_Dockerfile b/examples/fl_post/fl/project/be_Dockerfile new file mode 100644 index 000000000..b69f527c2 --- /dev/null +++ b/examples/fl_post/fl/project/be_Dockerfile @@ -0,0 +1,29 @@ +FROM local/openfl:local + +ENV LANG C.UTF-8 +ENV CUDA_VISIBLE_DEVICES="0" +# ENV 
http_proxy="http://proxy-us.intel.com:912" +# ENV https_proxy="http://proxy-us.intel.com:912" +ENV no_proxy=localhost,spr-gpu01.jf.intel.com + +ENV no_proxy________________="http://proxy-us.intel.com:912" + +# install project dependencies +RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ + +RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 + +COPY ./requirements.txt /mlcube_project/requirements.txt +RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Create similar env with cuda118 +RUN apt-get update && apt-get install python3.10-venv -y +RUN python -m venv /cuda118 +RUN /cuda118/bin/pip install --no-cache-dir /openfl +RUN /cuda118/bin/pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 +RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt + +# Copy mlcube project folder +COPY . /mlcube_project + +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 4dde6de8e..b1d4d5c54 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -158,6 +158,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Should we put this in a separate process? 
# TODO: Currently allowing at most 1 second of valiation over one batch in order to avoid NNUnet code throwing exception due # to empty val results + print(f"Brandon DEBUG - about to call train_nnunet with:\nTOTAL_max_num_epochs:{self.TOTAL_max_num_epochs}\ntrain_cutoff:{self.train_cutoff}\nval_cutoff:1") train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, @@ -193,6 +194,8 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? + print(f"Brandon DEBUG - about to call train_nnunet with:\nTOTAL_max_num_epochs:{self.TOTAL_max_num_epochs}\ntrain_cutoff:0\nval_cutoff:{self.val_cutoff}") + print(f"Recall that you may be getting in trouble here due to train_cutoff exiting without val work being done?") train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=1, current_epoch=current_epoch, @@ -200,7 +203,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name()) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") # double check if train_completed != 0.0: From 72603215730ecbc95b3dc1156ae48e5961495bce Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 09:27:46 -0700 Subject: [PATCH 150/242] putting back fl-poc versions of some files --- examples/fl_post/fl/project/Dockerfile | 6 +---- examples/fl_post/fl/setup_clean.sh | 9 -------- examples/fl_post/fl/test.sh | 32 +++++++------------------- 3 files changed, 9 insertions(+), 38 deletions(-) diff --git a/examples/fl_post/fl/project/Dockerfile b/examples/fl_post/fl/project/Dockerfile index 1ba403e20..d12baa7bb 100644 --- a/examples/fl_post/fl/project/Dockerfile +++ b/examples/fl_post/fl/project/Dockerfile @@ -2,11 +2,7 @@ FROM local/openfl:local ENV LANG C.UTF-8 ENV CUDA_VISIBLE_DEVICES="0" -# ENV http_proxy="http://proxy-us.intel.com:912" -# ENV https_proxy="http://proxy-us.intel.com:912" -ENV no_proxy=localhost,spr-gpu01.jf.intel.com -ENV no_proxy________="http://proxy-us.intel.com:912" # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ @@ -26,4 +22,4 @@ RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt # Copy mlcube project folder COPY . 
/mlcube_project -ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] +ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] \ No newline at end of file diff --git a/examples/fl_post/fl/setup_clean.sh b/examples/fl_post/fl/setup_clean.sh index 5db13b158..6615c2968 100644 --- a/examples/fl_post/fl/setup_clean.sh +++ b/examples/fl_post/fl/setup_clean.sh @@ -1,15 +1,6 @@ - -HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" - -cd $HOMEDIR - - - rm -rf ./mlcube_agg rm -rf ./mlcube_col1 rm -rf ./mlcube_col2 rm -rf ./mlcube_col3 -rm -rf ./mlcube_col4 -rm -rf ./mlcube_col5 rm -rf ./ca rm -rf ./for_admin diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 911ed987f..9463c56a1 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,18 +1,9 @@ -export HTTPS_PROXY= -export http_proxy= - -HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" - -cd $HOMEDIR - # generate plan and copy it to each node GENERATE_PLAN_PLATFORM="docker" AGG_PLATFORM="docker" -COL1_PLATFORM="docker" +COL1_PLATFORM="singularity" COL2_PLATFORM="docker" COL3_PLATFORM="docker" -COL4_PLATFORM="docker" -COL5_PLATFORM="docker" medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace @@ -20,17 +11,13 @@ rm -r ./mlcube_agg/workspace/plan cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col4/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col5/workspace cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e 
MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" -COL4="medperf --platform $COL4_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" -COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & @@ -38,19 +25,16 @@ COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./ml # gnome-terminal -- bash -c "$COL1; bash" # gnome-terminal -- bash -c "$COL2; bash" # gnome-terminal -- bash -c "$COL3; bash" -rm agg.log col1.log col2.log col3.log col4.log col5.log +rm agg.log col1.log col2.log col3.log $AGG >>agg.log & sleep 6 -$COL1 >>col1.log & -sleep 6 $COL2 >>col2.log & sleep 6 -$COL3 >> col3.log & -sleep 6 -$COL4 >> col4.log & -sleep 6 -$COL5 >> col5.log & - +$COL3 >>col3.log & +# sleep 6 +# $COL2 >> col2.log & +# sleep 6 +# $COL3 >> col3.log & wait # docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume 
/home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 From 180e5f0675fadcb5e392410fe446b09bda028441 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 10:32:38 -0700 Subject: [PATCH 151/242] changes to make sure validate method in nnunet runner does not save a new checkpoint with epoch one greater --- examples/fl_post/fl/be_test.sh | 6 +++--- examples/fl_post/fl/project/src/nnunet_v1.py | 6 +++++- examples/fl_post/fl/project/src/runner_nnunetv1.py | 6 +++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/fl_post/fl/be_test.sh b/examples/fl_post/fl/be_test.sh index 911ed987f..b4910d2d0 100755 --- a/examples/fl_post/fl/be_test.sh +++ b/examples/fl_post/fl/be_test.sh @@ -28,9 +28,9 @@ cp ./mlcube_agg/workspace/plan.yaml ./for_admin AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" -COL4="medperf --platform $COL4_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" -COL5="medperf --platform $COL5_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" +COL3="medperf --platform $COL3_PLATFORM --gpus=device=3 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +COL4="medperf --platform 
$COL4_PLATFORM --gpus=device=4 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" +COL5="medperf --platform $COL5_PLATFORM --gpus=device=5 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 77456c4de..5f2d58ec6 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -57,6 +57,7 @@ def seed_everything(seed=1234): def train_nnunet(TOTAL_max_num_epochs, epochs, current_epoch, + decrement_current_epoch_by_one=False, train_cutoff=np.inf, val_cutoff=np.inf, network='3d_fullres', @@ -84,6 +85,7 @@ def train_nnunet(TOTAL_max_num_epochs, TOTAL_max_num_epochs (int): Provides the total number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during the course of federated training) epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model + decrement_current_epoch_by_one (bool) : Whether or not to reduce the trainer epoch value by one after calling this function in order to offset increment already in function (used in validation only scenario) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
train_cutoff_part (float): Portion of train_val_cutoff going to training val_cutoff_part (float): Portion of train_val_cutoff going to val @@ -231,6 +233,8 @@ def __init__(self, **kwargs): trainer.save_latest_only = ( True # if false it will not store/overwrite _latest but separate files each ) + + # we will reset this to old epoch value before saving a new checkpoint within run_training call if decrement_current_epoch_by_one is True trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch @@ -268,7 +272,7 @@ def __init__(self, **kwargs): # new training without pretraine weights, do nothing pass - batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff) + batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, decrement_current_epoch_by_one=decrement_current_epoch_by_one) else: # if valbest: # trainer.load_best_checkpoint(train=False) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index b1d4d5c54..7bade89a6 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -185,8 +185,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint - # offseting current_epoch in order that it is set back to where it was in previous checkpoint after the rain call - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - 1 + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) # 2. 
Train/val function existing externally # Some todo inside function below # TODO: test for off-by-one error @@ -201,7 +200,8 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): current_epoch=current_epoch, train_cutoff=0, val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name()) + task=self.data_loader.get_task_name(), + decrement_current_epoch_by_one=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") From ac18c19413288cb1c3da18b67b43845da0b5feef Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 17:07:34 -0700 Subject: [PATCH 152/242] some debug statements --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7bade89a6..9aff9a828 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -149,6 +149,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"In col train method, loaded checkpoint with current epoch: {current_epoch}") # 2. Train/val function existing externally # Some todo inside function below # TODO: test for off-by-one error @@ -158,7 +159,6 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Should we put this in a separate process? 
# TODO: Currently allowing at most 1 second of valiation over one batch in order to avoid NNUnet code throwing exception due # to empty val results - print(f"Brandon DEBUG - about to call train_nnunet with:\nTOTAL_max_num_epochs:{self.TOTAL_max_num_epochs}\ntrain_cutoff:{self.train_cutoff}\nval_cutoff:1") train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, @@ -166,7 +166,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): val_cutoff = 1, task=self.data_loader.get_task_name()) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work.") + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # 3. Load metrics from checkpoint (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] @@ -186,6 +186,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}") # 2. Train/val function existing externally # Some todo inside function below # TODO: test for off-by-one error @@ -193,8 +194,6 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- print(f"Brandon DEBUG - about to call train_nnunet with:\nTOTAL_max_num_epochs:{self.TOTAL_max_num_epochs}\ntrain_cutoff:0\nval_cutoff:{self.val_cutoff}") - print(f"Recall that you may be getting in trouble here due to train_cutoff exiting without val work being done?") train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=1, current_epoch=current_epoch, From 0fc0aa0a76fedf6d1876b5414196237c95361b73 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 18:06:45 -0700 Subject: [PATCH 153/242] some debug statements --- examples/fl_post/fl/project/src/nnunet_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 5f2d58ec6..7a6a75f60 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -271,7 +271,7 @@ def __init__(self, **kwargs): else: # new training without pretraine weights, do nothing pass - + print(f"Brandon DEBUG - Calling trainer.run_training, trainer epoch: {trainer.epoch}, trainer max_num_epochs:{trainer.max_num_epochs}") batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, decrement_current_epoch_by_one=decrement_current_epoch_by_one) else: # if valbest: From ff8eb556ed8b0ab8ec90f1ee4cd31295fa521f66 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 18:28:09 -0700 Subject: [PATCH 154/242] some more debug statements --- examples/fl_post/fl/project/src/nnunet_v1.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 7a6a75f60..5973fb870 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -238,9 +238,13 @@ def __init__(self, **kwargs): trainer.max_num_epochs = current_epoch + epochs 
trainer.epoch = current_epoch + print(f"Brandon DEBUG - about to initialize trainer, currently t.max_num:{trainer.max_num_epochs}, t.epo:{trainer.epoch}") + # TODO: call validation separately trainer.initialize(not validation_only) + print(f"Brandon DEBUG - after initialize trainer, currently t.max_num:{trainer.max_num_epochs}, t.epo:{trainer.epoch}") + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times @@ -272,6 +276,8 @@ def __init__(self, **kwargs): # new training without pretraine weights, do nothing pass print(f"Brandon DEBUG - Calling trainer.run_training, trainer epoch: {trainer.epoch}, trainer max_num_epochs:{trainer.max_num_epochs}") + print(f"Brandon DEBUG - NOTE: this is where I had just loaded checkpoint.") + batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, decrement_current_epoch_by_one=decrement_current_epoch_by_one) else: # if valbest: From daaad2b4883eabcbdc49b19ffe786413a721a2b7 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 3 Oct 2024 19:43:18 -0700 Subject: [PATCH 155/242] New handling of epoch in checkpoint and new handling of lr scheduling, both using the param val_epoch --- examples/fl_post/fl/project/src/nnunet_v1.py | 7 +++---- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 5973fb870..b26269a59 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -57,7 +57,7 @@ def seed_everything(seed=1234): def train_nnunet(TOTAL_max_num_epochs, epochs, current_epoch, - decrement_current_epoch_by_one=False, + val_epoch=False, train_cutoff=np.inf, val_cutoff=np.inf, network='3d_fullres', @@ -85,7 +85,7 @@ def 
train_nnunet(TOTAL_max_num_epochs, TOTAL_max_num_epochs (int): Provides the total number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during the course of federated training) epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model - decrement_current_epoch_by_one (bool) : Whether or not to reduce the trainer epoch value by one after calling this function in order to offset increment already in function (used in validation only scenario) + val_epoch (bool) : Used in validation only scenario, makes lr scheduler not step and epoch to not incement upon saving final checkpoint train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. train_cutoff_part (float): Portion of train_val_cutoff going to training val_cutoff_part (float): Portion of train_val_cutoff going to val @@ -234,7 +234,6 @@ def __init__(self, **kwargs): True # if false it will not store/overwrite _latest but separate files each ) - # we will reset this to old epoch value before saving a new checkpoint within run_training call if decrement_current_epoch_by_one is True trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch @@ -278,7 +277,7 @@ def __init__(self, **kwargs): print(f"Brandon DEBUG - Calling trainer.run_training, trainer epoch: {trainer.epoch}, trainer max_num_epochs:{trainer.max_num_epochs}") print(f"Brandon DEBUG - NOTE: this is where I had just loaded checkpoint.") - batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, decrement_current_epoch_by_one=decrement_current_epoch_by_one) + batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, val_epoch=val_epoch) else: # if valbest: # trainer.load_best_checkpoint(train=False) diff --git 
a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 9aff9a828..59412bcc8 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -200,7 +200,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): train_cutoff=0, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), - decrement_current_epoch_by_one=True) + val_epoch=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") From 5fee48a2b12359959722486d04fe4f80da73c8c4 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 11:42:04 -0700 Subject: [PATCH 156/242] passing self.TOTAL_max_num_epochs to lr scheduler computation --- examples/fl_post/fl/project/src/nnunet_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index b26269a59..3754901ec 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -262,7 +262,7 @@ def __init__(self, **kwargs): return if find_lr: - trainer.find_lr() + trainer.find_lr(num_iters=self.TOTAL_max_num_epochs) else: if not validation_only: if args.continue_training: From 58cedc16d822baf141e5231abb7d07f157de68bb Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 12:11:04 -0700 Subject: [PATCH 157/242] will try absolutely no train/val in val/train cases respectively --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 59412bcc8..c9023fa35 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ 
b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -163,7 +163,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, - val_cutoff = 1, + val_cutoff = 0, task=self.data_loader.get_task_name()) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") From 21d0c0e98aa3e5c8e889c73fc4718ad47a36c7a0 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 16:29:31 -0700 Subject: [PATCH 158/242] new changes to keep global results out of checkpoints, returning metrics directly from nnunet train call, should help to keep checkpoints cleaner and nnunet code from throwing exceptions --- .../fl/mlcube/workspace/training_config.yaml | 2 + examples/fl_post/fl/project/src/nnunet_v1.py | 13 ++++- .../fl_post/fl/project/src/runner_nnunetv1.py | 49 +++++++++++-------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 5538220c8..c1c65923e 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -59,6 +59,7 @@ tasks : metrics : - val_eval apply : global + val_results_to_checkpoint : false train: function : train kwargs : @@ -71,6 +72,7 @@ tasks : metrics : - val_eval apply : local + val_results_to_checkpoint : true compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 3754901ec..437bc577b 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -58,6 +58,7 @@ def train_nnunet(TOTAL_max_num_epochs, epochs, current_epoch, 
val_epoch=False, + val_results_to_checkpoint=False, train_cutoff=np.inf, val_cutoff=np.inf, network='3d_fullres', @@ -86,6 +87,7 @@ def train_nnunet(TOTAL_max_num_epochs, epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model val_epoch (bool) : Used in validation only scenario, makes lr scheduler not step and epoch to not incement upon saving final checkpoint + val_results_to_checkpoint (bool) : Whether or not to store the val results in a class attribute that will then land in the checkpoint (we will only store local val in checkpoints) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. train_cutoff_part (float): Portion of train_val_cutoff going to training val_cutoff_part (float): Portion of train_val_cutoff going to val @@ -277,7 +279,14 @@ def __init__(self, **kwargs): print(f"Brandon DEBUG - Calling trainer.run_training, trainer epoch: {trainer.epoch}, trainer max_num_epochs:{trainer.max_num_epochs}") print(f"Brandon DEBUG - NOTE: this is where I had just loaded checkpoint.") - batches_applied_train, batches_applied_val = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, val_epoch=val_epoch) + batches_applied_train, \ + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics = trainer.run_training(train_cutoff=train_cutoff, + val_cutoff=val_cutoff, + val_epoch=val_epoch, + val_results_to_checkpoint=val_results_to_checkpoint) else: # if valbest: # trainer.load_best_checkpoint(train=False) @@ -288,6 +297,6 @@ def __init__(self, **kwargs): train_completed = batches_applied_train / float(num_train_batches_per_epoch) val_completed = batches_applied_val / float(num_val_batches_per_epoch) - return train_completed, val_completed + return train_completed, val_completed, this_ave_train_loss, this_ave_val_loss, this_val_eval_metrics diff --git 
a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index c9023fa35..efc208a46 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -159,19 +159,26 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Should we put this in a separate process? # TODO: Currently allowing at most 1 second of valiation over one batch in order to avoid NNUnet code throwing exception due # to empty val results - train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, val_cutoff = 0, - task=self.data_loader.get_task_name()) - + task=self.data_loader.get_task_name() + val_epoch=False + val_results_to_checkpoint=False) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - # 3. Load metrics from checkpoint - (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkpoint each call to train_nnunet, so it is critical that we are grabbing this right after the call above - metrics = {'train_loss': all_tr_losses[-1]} + # double check + if val_completed != 0.0: + raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") + + # 3. 
Prepare metrics + metrics = {'train_loss': this_ave_train_loss} ###################################################################################################### # TODO: Provide train_completed to be incorporated into the collab weight computation @@ -179,7 +186,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + def validate(self, col_name, round_num, input_tensor_dict,val_results_to_checkpoint, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" @@ -194,24 +201,26 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? - train_completed, val_completed = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, - epochs=1, - current_epoch=current_epoch, - train_cutoff=0, - val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name(), - val_epoch=True) - + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + epochs=1, + current_epoch=current_epoch, + train_cutoff=0, + val_cutoff = self.val_cutoff, + task=self.data_loader.get_task_name(), + val_epoch=True, + val_results_to_checkpoint=val_results_to_checkpoint) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # double check if train_completed != 0.0: raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") - # 3. 
Load metrics from checkpoint - (_, all_val_losses, _, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each call to train_nnunet, so it is critical that we are grabbing this right after the call above - metrics = {'val_eval': all_val_eval_metrics[-1]} + # 3. Prepare metrics + metrics = {'val_eval': this_val_eval_metrics} ###################################################################################################### From 47c76370586671294810e53ce874f28ea64588e0 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 17:01:36 -0700 Subject: [PATCH 159/242] fixing syntax error --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index efc208a46..a9d1eb911 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -161,7 +161,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # to empty val results train_completed, \ val_completed, \ - this_ave_train_loss, \ + this_ave_train_loss, \ this_ave_val_loss, \ this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, From 7aaa31a00921a989a53410b4af7521a93163a3fa Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 17:17:46 -0700 Subject: [PATCH 160/242] syntax --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index a9d1eb911..ac008c129 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -168,8 +168,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): 
current_epoch=current_epoch, train_cutoff=self.train_cutoff, val_cutoff = 0, - task=self.data_loader.get_task_name() - val_epoch=False + task=self.data_loader.get_task_name(), + val_epoch=False, val_results_to_checkpoint=False) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") From 3d2871bc4e87e47da63972e58ddb06426f09f2e4 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 19:19:36 -0700 Subject: [PATCH 161/242] function for local val now only grabs metrics from checkpoint --- .../fl_post/fl/mlcube/workspace/training_config.yaml | 3 +-- examples/fl_post/fl/project/src/runner_nnunetv1.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index c1c65923e..5adcd0c50 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -67,12 +67,11 @@ tasks : - train_loss epochs : 1 locally_tuned_model_validation: - function : validate + function : validate_by_reading_checkpoint kwargs : metrics : - val_eval apply : local - val_results_to_checkpoint : true compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index ac008c129..4af00e57b 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -238,4 +238,14 @@ def load_metrics(self, filepath): with open(filepath) as json_file: metrics = json.load(json_file) return metrics - """ \ No newline at end of file + """ + + + # TODO here, save train_completed and val_completed as class attributes + # WORKING HERE + + def 
validate_by_reading_checkpoint(self, col_name, round_num, input_tensor_dict, **kwargs): + (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after + metrics = {'train_loss': all_tr_losses[-1]} + return self.convert_results_to_tensorkeys(col_name, round_num, metrics) \ No newline at end of file From 2682d173b5043177b83c4d8b82a962b302fdc3f4 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 7 Oct 2024 20:02:17 -0700 Subject: [PATCH 162/242] having now to designate both val_epoch and train_epoch booleans at the nnunet train function level --- examples/fl_post/fl/project/src/nnunet_v1.py | 9 ++++++--- examples/fl_post/fl/project/src/runner_nnunetv1.py | 8 +++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 437bc577b..6206cd169 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -57,7 +57,8 @@ def seed_everything(seed=1234): def train_nnunet(TOTAL_max_num_epochs, epochs, current_epoch, - val_epoch=False, + val_epoch=True, + train_epoch=True, val_results_to_checkpoint=False, train_cutoff=np.inf, val_cutoff=np.inf, @@ -86,7 +87,8 @@ def train_nnunet(TOTAL_max_num_epochs, TOTAL_max_num_epochs (int): Provides the total number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during the course of federated training) epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model - val_epoch (bool) : Used in validation only scenario, makes lr scheduler not step and epoch to not incement upon saving final checkpoint + val_epoch (bool) : Will validation be performed + train_epoch (bool) : Will training run (rather than val only) makes lr step 
and epoch increment val_results_to_checkpoint (bool) : Whether or not to store the val results in a class attribute that will then land in the checkpoint (we will only store local val in checkpoints) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. train_cutoff_part (float): Portion of train_val_cutoff going to training @@ -285,7 +287,8 @@ def __init__(self, **kwargs): this_ave_val_loss, \ this_val_eval_metrics = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, - val_epoch=val_epoch, + val_epoch=val_epoch, + train_epoch=train_epoch, val_results_to_checkpoint=val_results_to_checkpoint) else: # if valbest: diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 4af00e57b..7fbdcee89 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -167,10 +167,11 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, - val_cutoff = 0, + val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), - val_epoch=False, - val_results_to_checkpoint=False) + val_epoch=True, + train_epoch=True, + val_results_to_checkpoint=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # double check @@ -212,6 +213,7 @@ def validate(self, col_name, round_num, input_tensor_dict,val_results_to_checkpo val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), val_epoch=True, + train_epoch=False, val_results_to_checkpoint=val_results_to_checkpoint) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") From 19c3dff6284a24178d4e4a2420a859a2f1a82e84 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 08:29:11 -0700 Subject: [PATCH 163/242] removing variable no longer used --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 1 - examples/fl_post/fl/project/src/nnunet_v1.py | 5 +---- examples/fl_post/fl/project/src/runner_nnunetv1.py | 8 +++----- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 5adcd0c50..8c18249e5 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -59,7 +59,6 @@ tasks : metrics : - val_eval apply : global - val_results_to_checkpoint : false train: function : train kwargs : diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 6206cd169..3f117e8a0 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -59,7 +59,6 @@ def train_nnunet(TOTAL_max_num_epochs, current_epoch, val_epoch=True, train_epoch=True, - val_results_to_checkpoint=False, train_cutoff=np.inf, val_cutoff=np.inf, network='3d_fullres', @@ -89,7 +88,6 @@ def train_nnunet(TOTAL_max_num_epochs, current_epoch (int): Which epoch will be used to grab the model val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) makes lr step and epoch increment - val_results_to_checkpoint (bool) : Whether or not to store the val results in a class attribute that will then land in the checkpoint (we will only store local val in checkpoints) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
train_cutoff_part (float): Portion of train_val_cutoff going to training val_cutoff_part (float): Portion of train_val_cutoff going to val @@ -288,8 +286,7 @@ def __init__(self, **kwargs): this_val_eval_metrics = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, val_epoch=val_epoch, - train_epoch=train_epoch, - val_results_to_checkpoint=val_results_to_checkpoint) + train_epoch=train_epoch) else: # if valbest: # trainer.load_best_checkpoint(train=False) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7fbdcee89..c3b60937d 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -170,8 +170,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), val_epoch=True, - train_epoch=True, - val_results_to_checkpoint=True) + train_epoch=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") # double check @@ -187,7 +186,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - def validate(self, col_name, round_num, input_tensor_dict,val_results_to_checkpoint, **kwargs): + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" @@ -213,8 +212,7 @@ def validate(self, col_name, round_num, input_tensor_dict,val_results_to_checkpo val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), val_epoch=True, - train_epoch=False, - val_results_to_checkpoint=val_results_to_checkpoint) + train_epoch=False) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # double check From 059d1264cea11c20c69596f438899e3c030159ab Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 09:30:53 -0700 Subject: [PATCH 164/242] removing check for no val work done during train call, no longer the case --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index c3b60937d..299e13e4c 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -173,10 +173,6 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): train_epoch=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") - # double check - if val_completed != 0.0: - raise ValueError(f"Tried to train only, but got a non-zero amount ({val_completed}) of validation done.") - # 3. Prepare metrics metrics = {'train_loss': this_ave_train_loss} From cda5de9f82ce948d63522df22174f93a5de431bc Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 10:24:41 -0700 Subject: [PATCH 165/242] moving validation from checkpoint and not both under function 'validate' --- .../fl_post/fl/project/src/runner_nnunetv1.py | 86 +++++++++++-------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 299e13e4c..17a0624b1 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -182,42 +182,58 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" - self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) - # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}") - # 2. Train/val function existing externally - # Some todo inside function below - # TODO: test for off-by-one error - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? 
- train_completed, \ - val_completed, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, - epochs=1, - current_epoch=current_epoch, - train_cutoff=0, - val_cutoff = self.val_cutoff, - task=self.data_loader.get_task_name(), - val_epoch=True, - train_epoch=False) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - - # double check - if train_completed != 0.0: - raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") - - # 3. Prepare metrics - metrics = {'val_eval': this_val_eval_metrics} - + def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): + hash_1 = np.sum[torch.mean(_value) for _value in td_1.values()] + hash_2 = np.sum[torch.mean(_value) for _value in td_2.values()] + delta = np.abs(hash_1 - hash_2) + if delta > epsilon: + raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") + + + if not from_checkpoint: + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}") + # 2. Train/val function existing externally + # Some todo inside function below + # TODO: test for off-by-one error + + # FIXME: we need to understand how to use round_num instead of current_epoch + # this will matter in straggler handling cases + # TODO: Should we put this in a separate process? 
+ train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + epochs=1, + current_epoch=current_epoch, + train_cutoff=0, + val_cutoff = self.val_cutoff, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=False) + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # 3. Prepare metrics + metrics = {'val_eval': this_val_eval_metrics} + else: + checkpoint_dict = self.load_checkpoint() + # double check + compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict']) + + (all_tr_losses, _, _, _) = checkpoint_dict['plot_stuff'] + # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after + metrics = {'train_loss': all_tr_losses[-1]} ###################################################################################################### # TODO: Provide val_completed to be incorporated into the collab weight computation @@ -241,7 +257,5 @@ def load_metrics(self, filepath): # WORKING HERE def validate_by_reading_checkpoint(self, col_name, round_num, input_tensor_dict, **kwargs): - (all_tr_losses, _, _, _) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after - metrics = {'train_loss': all_tr_losses[-1]} + fjkdls;jafkdls;jfkdsl; return self.convert_results_to_tensorkeys(col_name, round_num, metrics) \ No newline at end of file From 24fe1b2d62580c3e5ee7d05ddd2fceec11714e06 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 10:27:47 
-0700 Subject: [PATCH 166/242] adding single val funciton to training config --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 8c18249e5..eaeee2372 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -66,11 +66,12 @@ tasks : - train_loss epochs : 1 locally_tuned_model_validation: - function : validate_by_reading_checkpoint + function : validate kwargs : metrics : - val_eval apply : local + from_checkpoint: true compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml From d7d838eaa36fd67386319645df601ee743fa458f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 10:41:13 -0700 Subject: [PATCH 167/242] syntax --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 17a0624b1..7e76f298f 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -187,8 +187,8 @@ def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False """Perform validation.""" def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): - hash_1 = np.sum[torch.mean(_value) for _value in td_1.values()] - hash_2 = np.sum[torch.mean(_value) for _value in td_2.values()] + hash_1 = np.sum([torch.mean(_value) for _value in td_1.values()]) + hash_2 = np.sum([torch.mean(_value) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) if delta > epsilon: raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") From 91a95232d11b7653317cb649c4f30921d564c70f 
Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 11:21:12 -0700 Subject: [PATCH 168/242] replaced np.sum for torch.sum, and preparing for per_task data size using amount of work completed --- .../fl_post/fl/project/src/runner_nnunetv1.py | 47 +++++++++++++++---- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7e76f298f..520de9aee 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -81,6 +81,13 @@ def __init__(self, self.val_cutoff = val_cutoff self.config_path = config_path self.TOTAL_max_num_epochs=TOTAL_max_num_epochs + + # self.task_completed is a dictionary of task to amount completed as a float in [0,1] + # Values will be dynamically updated + # TODO: Tasks are hard coded for now + self.task_completed = {'aggregated_model_validation': 1.0, + 'train': 1.0, + 'locally_tuned_model_validation': 1.0} def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -171,6 +178,10 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): task=self.data_loader.get_task_name(), val_epoch=True, train_epoch=True) + # update amount of task completed + self.task_completed['train'] = train_completed + self.task_completed['locally_tuned_model_validation'] = val_completed + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # 3. 
Prepare metrics @@ -187,8 +198,8 @@ def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False """Perform validation.""" def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): - hash_1 = np.sum([torch.mean(_value) for _value in td_1.values()]) - hash_2 = np.sum([torch.mean(_value) for _value in td_2.values()]) + hash_1 = np.sum([np.mean(_value) for _value in td_1.values()]) + hash_2 = np.sum([np.mean(_value) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) if delta > epsilon: raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") @@ -218,12 +229,16 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): task=self.data_loader.get_task_name(), val_epoch=True, train_epoch=False) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - # double check if train_completed != 0.0: raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + # update amount of task completed + self.task_completed['aggregated_model_validation'] = val_completed + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + # 3. 
Prepare metrics metrics = {'val_eval': this_val_eval_metrics} else: @@ -253,9 +268,23 @@ def load_metrics(self, filepath): """ - # TODO here, save train_completed and val_completed as class attributes - # WORKING HERE + # TODO to support below, save train_completed and val_completed as class attributes + # WORKING HERE, for now turned off due to task_dependent default - def validate_by_reading_checkpoint(self, col_name, round_num, input_tensor_dict, **kwargs): - fjkdls;jafkdls;jfkdsl; - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) \ No newline at end of file + def get_train_data_size(self, task_dependent=False, task=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. + + Returns: + int: The number of training examples. + """ + if not task_dependent: + return self.data_loader.get_train_data_size() + elif not task: + raise ValueError(f"If using task dependent data size, must provide task.") + else: + # self.task_completed is a dictionary of task to amount completed as a float in [0,1] + return self.task_completed[task] * self.data_loader.get_train_data_size() From 9b2f8ee49812b55de49d14daf70fddeab52321be Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 12:34:00 -0700 Subject: [PATCH 169/242] was using np on tensors --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 520de9aee..fd986c97d 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -198,8 +198,8 @@ def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False """Perform validation.""" def 
compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): - hash_1 = np.sum([np.mean(_value) for _value in td_1.values()]) - hash_2 = np.sum([np.mean(_value) for _value in td_2.values()]) + hash_1 = np.sum([torch.mean(_value) for _value in td_1.values()]) + hash_2 = np.sum([torch.mean(_value) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) if delta > epsilon: raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") From dc495eab80951fc2c398a3de78c780545e3a1077 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 12:51:39 -0700 Subject: [PATCH 170/242] casting to numpy (some are already, and some are torch tensors) --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index fd986c97d..09ba59d83 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -198,8 +198,8 @@ def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False """Perform validation.""" def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): - hash_1 = np.sum([torch.mean(_value) for _value in td_1.values()]) - hash_2 = np.sum([torch.mean(_value) for _value in td_2.values()]) + hash_1 = np.sum([np.mean(np.array(_value)) for _value in td_1.values()]) + hash_2 = np.sum([np.mean(np.array(_value)) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) if delta > epsilon: raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") From e65eeef38c8aa0ad13cbb7d4aedbe2f78def3f85 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 14:48:40 -0700 Subject: [PATCH 171/242] typo --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 09ba59d83..781c9e88f 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -202,7 +202,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): hash_2 = np.sum([np.mean(np.array(_value)) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) if delta > epsilon: - raise VaueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") + raise ValueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") if not from_checkpoint: @@ -244,7 +244,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): else: checkpoint_dict = self.load_checkpoint() # double check - compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict']) + compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") (all_tr_losses, _, _, _) = checkpoint_dict['plot_stuff'] # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after From aaf61fa0c5708273ef66cd7756d37631c4a2e744 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 15:06:33 -0700 Subject: [PATCH 172/242] setting checkpoint delta error to 0.1 for now as well as verbose printing of delta --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 781c9e88f..d89185c28 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -197,10 +197,12 @@ def validate(self, col_name, round_num, input_tensor_dict, 
from_checkpoint=False # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" - def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.0001): + def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): hash_1 = np.sum([np.mean(np.array(_value)) for _value in td_1.values()]) hash_2 = np.sum([np.mean(np.array(_value)) for _value in td_2.values()]) delta = np.abs(hash_1 - hash_2) + if verbose: + print(f"The tensor dict comparison {tag} resulted in delta: {delta} (accepted error: {epsilon}).") if delta > epsilon: raise ValueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") From 766c177530e23a0696b3abe3da351debbf15a55a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 8 Oct 2024 15:28:45 -0700 Subject: [PATCH 173/242] correcting metric that had train loss instead of val metric --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index d89185c28..4a60c931d 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -248,9 +248,12 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): # double check compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") - (all_tr_losses, _, _, _) = checkpoint_dict['plot_stuff'] + all_tr_losses, \ + all_val_losses, \ + all_val_losses_tr_mode, \ + all_val_eval_metrics = checkpoint_dict['plot_stuff'] # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after - metrics = {'train_loss': all_tr_losses[-1]} + metrics = {'val_eval': all_val_eval_metrics[-1]} 
###################################################################################################### # TODO: Provide val_completed to be incorporated into the collab weight computation From 67e346738296917270f5467c84e6b49dc92be422 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 12:12:03 -0700 Subject: [PATCH 174/242] removing some testing output --- examples/fl_post/fl/project/src/nnunet_v1.py | 7 ------- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 3f117e8a0..ffab4f131 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -239,13 +239,8 @@ def __init__(self, **kwargs): trainer.max_num_epochs = current_epoch + epochs trainer.epoch = current_epoch - print(f"Brandon DEBUG - about to initialize trainer, currently t.max_num:{trainer.max_num_epochs}, t.epo:{trainer.epoch}") - - # TODO: call validation separately trainer.initialize(not validation_only) - print(f"Brandon DEBUG - after initialize trainer, currently t.max_num:{trainer.max_num_epochs}, t.epo:{trainer.epoch}") - # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times @@ -276,8 +271,6 @@ def __init__(self, **kwargs): else: # new training without pretraine weights, do nothing pass - print(f"Brandon DEBUG - Calling trainer.run_training, trainer epoch: {trainer.epoch}, trainer max_num_epochs:{trainer.max_num_epochs}") - print(f"Brandon DEBUG - NOTE: this is where I had just loaded checkpoint.") batches_applied_train, \ batches_applied_val, \ diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 4a60c931d..3cb52de50 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py 
+++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -245,8 +245,8 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): metrics = {'val_eval': this_val_eval_metrics} else: checkpoint_dict = self.load_checkpoint() - # double check - compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") + # double check uncomment below for testing + # compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") all_tr_losses, \ all_val_losses, \ From 478865929ecc02a7e17175621f478d8383156930 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 13:03:45 -0700 Subject: [PATCH 175/242] enabling task dependent data size --- .../fl_post/fl/project/src/runner_nnunetv1.py | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 3cb52de50..0f762ad5a 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,7 +33,7 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=np.inf, + train_cutoff=3, val_cutoff=np.inf, nnunet_task=None, config_path=None, @@ -178,6 +178,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): task=self.data_loader.get_task_name(), val_epoch=True, train_epoch=True) + # update amount of task completed self.task_completed['train'] = train_completed self.task_completed['locally_tuned_model_validation'] = val_completed @@ -187,9 +188,6 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # 3. 
Prepare metrics metrics = {'train_loss': this_ave_train_loss} - ###################################################################################################### - # TODO: Provide train_completed to be incorporated into the collab weight computation - ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) @@ -255,9 +253,6 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after metrics = {'val_eval': all_val_eval_metrics[-1]} - ###################################################################################################### - # TODO: Provide val_completed to be incorporated into the collab weight computation - ###################################################################################################### return self.convert_results_to_tensorkeys(col_name, round_num, metrics) @@ -273,10 +268,7 @@ def load_metrics(self, filepath): """ - # TODO to support below, save train_completed and val_completed as class attributes - # WORKING HERE, for now turned off due to task_dependent default - - def get_train_data_size(self, task_dependent=False, task=None): + def get_train_data_size(self, task_dependent=False, task_name=None): """Get the number of training examples. It will be used for weighted averaging in aggregation. @@ -284,12 +276,31 @@ def get_train_data_size(self, task_dependent=False, task=None): allowing dynamic weighting by storing recent appropriate weights in class attributes. Returns: - int: The number of training examples. + float: The number of training examples, weighted by how much of the task got completed. 
""" if not task_dependent: return self.data_loader.get_train_data_size() - elif not task: - raise ValueError(f"If using task dependent data size, must provide task.") + elif not task_name: + raise ValueError(f"If using task dependent data size, must provide task_name.") + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return self.task_completed[task_name] * self.data_loader.get_train_data_size() + + + def get_valid_data_size(self, task_dependent=False, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. + + Returns: + float: The number of training examples, weighted by how much of the task got completed. + """ + if not task_dependent: + return self.data_loader.get_valid_data_size() + elif not task_name: + raise ValueError(f"If using task dependent data size, must provide task_name.") else: - # self.task_completed is a dictionary of task to amount completed as a float in [0,1] - return self.task_completed[task] * self.data_loader.get_train_data_size() + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return self.task_completed[task_name] * self.data_loader.get_valid_data_size() From 75e14e1c3ef5c2e5d636b29cf0a52e4b4d510a9a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 14:52:34 -0700 Subject: [PATCH 176/242] commenting out debug print --- examples/fl_post/fl/project/hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/hooks.py b/examples/fl_post/fl/project/hooks.py index ff1cda9ee..ab46e0f1e 100644 --- a/examples/fl_post/fl/project/hooks.py +++ b/examples/fl_post/fl/project/hooks.py @@ -54,7 +54,7 @@ def collaborator_pre_training_hook( # when evan runs, init_model_path, init_model_info_path should be 
None # plans_path should also be None (the returned thing will point to where it lives so that it will be synced with others) - print(f"Brandon DEBUG - postopp_pardir will be pointed to: {workspace_folder} which has data subfolder containing: {os.listdir(os.path.join(workspace_folder, 'data'))}") + # print(f"Brandon DEBUG - postopp_pardir will be pointed to: {workspace_folder} which has data subfolder containing: {os.listdir(os.path.join(workspace_folder, 'data'))}") nnunet_setup.main(postopp_pardir=workspace_folder, three_digit_task_num=537, # FIXME: does this need to be set in any particular way? From 3ebfc6ba2b4b0faba632f8a9224158d99a84d368 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 14:53:32 -0700 Subject: [PATCH 177/242] adding recent changes to intel_build for local openfl changes --- examples/fl_post/fl/intel_build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/intel_build.sh b/examples/fl_post/fl/intel_build.sh index 54494127a..7f8cd19f7 100755 --- a/examples/fl_post/fl/intel_build.sh +++ b/examples/fl_post/fl/intel_build.sh @@ -11,7 +11,7 @@ cp ./project/be_Dockerfile ./project/Dockerfile if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 54f27c61c274f64af3d028f962f62392419cb67e + git checkout be_enable_partial_epochs docker build \ --build-arg http_proxy="http://proxy-us.intel.com:912" \ --build-arg https_proxy="http://proxy-us.intel.com:912" \ From 18764cc8648da08b7e0387b60a4e3c67b9023700 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 14:59:30 -0700 Subject: [PATCH 178/242] ensuring get_data_size returns an int so as to satisfy proto schema --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 0f762ad5a..b4c249146 100644 --- 
a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -276,7 +276,7 @@ def get_train_data_size(self, task_dependent=False, task_name=None): allowing dynamic weighting by storing recent appropriate weights in class attributes. Returns: - float: The number of training examples, weighted by how much of the task got completed. + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema """ if not task_dependent: return self.data_loader.get_train_data_size() @@ -284,7 +284,7 @@ def get_train_data_size(self, task_dependent=False, task_name=None): raise ValueError(f"If using task dependent data size, must provide task_name.") else: # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return self.task_completed[task_name] * self.data_loader.get_train_data_size() + return int(np.ceil(self.task_completed[task_name] * self.data_loader.get_train_data_size())) def get_valid_data_size(self, task_dependent=False, task_name=None): @@ -295,7 +295,7 @@ def get_valid_data_size(self, task_dependent=False, task_name=None): allowing dynamic weighting by storing recent appropriate weights in class attributes. Returns: - float: The number of training examples, weighted by how much of the task got completed. 
+ int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema """ if not task_dependent: return self.data_loader.get_valid_data_size() @@ -303,4 +303,4 @@ def get_valid_data_size(self, task_dependent=False, task_name=None): raise ValueError(f"If using task dependent data size, must provide task_name.") else: # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return self.task_completed[task_name] * self.data_loader.get_valid_data_size() + return int(np.ceil(self.task_completed[task_name] * self.data_loader.get_valid_data_size())) From aa2d300f8a5dab37bd6667e8088c795c9f4545c9 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 10 Oct 2024 18:05:13 -0700 Subject: [PATCH 179/242] changing test setting --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index b4c249146..f4ce2c182 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,7 +33,7 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=3, + train_cutoff=100, val_cutoff=np.inf, nnunet_task=None, config_path=None, From 0dc6eea1559b57ab697ce6f9c50c19fdb2a9fb54 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 11 Oct 2024 12:02:54 -0700 Subject: [PATCH 180/242] only including model in task results in train task (not local or global val) --- .../fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- .../fl_post/fl/project/src/runner_pt_chkpt.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 
f4ce2c182..b560cff85 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -188,7 +188,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # 3. Prepare metrics metrics = {'train_loss': this_ave_train_loss} - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) + return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs): @@ -253,7 +253,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after metrics = {'val_eval': all_val_eval_metrics[-1]} - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) + return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) def load_metrics(self, filepath): diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index 6ab7851b9..0c4456b53 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -255,7 +255,9 @@ def _read_opt_state_from_checkpoint(self, checkpoint_dict): return derived_opt_state_dict - def convert_results_to_tensorkeys(self, col_name, round_num, metrics): + def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_model): + # insert_model determined whether or not to include the model in the return dictionaries + # 5. 
Convert to tensorkeys # output metric tensors (scalar) @@ -268,11 +270,14 @@ def convert_results_to_tensorkeys(self, col_name, round_num, metrics): metrics[metric_name] ) for metric_name in metrics} - # output model tensors (Doesn't include TensorKey) - output_model_dict = self.get_tensor_dict(with_opt_vars=True) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, - tensor_dict=output_model_dict, - **self.tensor_dict_split_fn_kwargs) + if include_model: + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) + else: + global_model_dict, local_model_dict = {}, {} # create global tensorkeys global_tensorkey_model_dict = { From 3ecb0781c24c3e687f8adeddab0d345522483cbb Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 11 Oct 2024 13:10:01 -0700 Subject: [PATCH 181/242] inconsistency in new variable name --- examples/fl_post/fl/project/src/runner_pt_chkpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index 0c4456b53..ac7568c20 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -270,7 +270,7 @@ def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_mod metrics[metric_name] ) for metric_name in metrics} - if include_model: + if insert_model: # output model tensors (Doesn't include TensorKey) output_model_dict = self.get_tensor_dict(with_opt_vars=True) global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, From c7cdd8f8b4e9bad5ec66281f1a0125e172e4070f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 11 Oct 2024 16:14:04 -0700 Subject: [PATCH 
182/242] now including per label DICE for global and local val --- .../fl/mlcube/workspace/training_config.yaml | 8 +++++ .../fl_post/fl/project/src/runner_nnunetv1.py | 30 +++++++++++++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index eaeee2372..d6029fa8b 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -58,6 +58,10 @@ tasks : kwargs : metrics : - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 apply : global train: function : train @@ -70,6 +74,10 @@ tasks : kwargs : metrics : - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 apply : local from_checkpoint: true diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index b560cff85..8de8fbc56 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -170,7 +170,11 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): val_completed, \ this_ave_train_loss, \ this_ave_val_loss, \ - this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, @@ -221,7 +225,11 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): val_completed, \ this_ave_train_loss, \ this_ave_val_loss, \ - this_val_eval_metrics = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, + this_val_eval_metrics_C3, \ 
+ this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=1, current_epoch=current_epoch, train_cutoff=0, @@ -240,7 +248,11 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): # 3. Prepare metrics - metrics = {'val_eval': this_val_eval_metrics} + metrics = {'val_eval': this_val_eval_metrics, + 'val_eval_C1': this_val_eval_metrics_C1, + 'val_eval_C2': this_val_eval_metrics_C2, + 'val_eval_C3': this_val_eval_metrics_C3, + 'val_eval_C4': this_val_eval_metrics_C4} else: checkpoint_dict = self.load_checkpoint() # double check uncomment below for testing @@ -249,9 +261,17 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): all_tr_losses, \ all_val_losses, \ all_val_losses_tr_mode, \ - all_val_eval_metrics = checkpoint_dict['plot_stuff'] + all_val_eval_metrics, \ + all_val_eval_metrics_C1, \ + all_val_eval_metrics_C2, + all_val_eval_metrics_C3, \ + all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after - metrics = {'val_eval': all_val_eval_metrics[-1]} + metrics = {'val_eval': all_val_eval_metrics[-1], + 'val_eval_C1': this_val_eval_metrics_C1[-1], + 'val_eval_C2': this_val_eval_metrics_C2[-1], + 'val_eval_C3': this_val_eval_metrics_C3[-1], + 'val_eval_C4': this_val_eval_metrics_C4[-1]} return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) From 5e10b47b623856dce2887d5ee496641ee81ca98d Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 11 Oct 2024 16:52:42 -0700 Subject: [PATCH 183/242] correcting syntax --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 8de8fbc56..2f07e3cd1 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py 
+++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -227,7 +227,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_ave_val_loss, \ this_val_eval_metrics, \ this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, + this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=1, @@ -263,7 +263,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): all_val_losses_tr_mode, \ all_val_eval_metrics, \ all_val_eval_metrics_C1, \ - all_val_eval_metrics_C2, + all_val_eval_metrics_C2, \ all_val_eval_metrics_C3, \ all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after From efd4e40d405fe7c5879dfc32f11b62782cb2d836 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 11 Oct 2024 17:28:29 -0700 Subject: [PATCH 184/242] nnunet training function should return more --- examples/fl_post/fl/project/src/nnunet_v1.py | 22 +++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index ffab4f131..96ef46b2e 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -273,10 +273,14 @@ def __init__(self, **kwargs): pass batches_applied_train, \ - batches_applied_val, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics = trainer.run_training(train_cutoff=train_cutoff, + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, val_cutoff=val_cutoff, val_epoch=val_epoch, train_epoch=train_epoch) @@ -290,6 +294,14 @@ def 
__init__(self, **kwargs): train_completed = batches_applied_train / float(num_train_batches_per_epoch) val_completed = batches_applied_val / float(num_val_batches_per_epoch) - return train_completed, val_completed, this_ave_train_loss, this_ave_val_loss, this_val_eval_metrics + return batches_applied_train, \ + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 From b4dc6603ba30a84e5d5885f6c012aaace95ca68e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Sun, 13 Oct 2024 13:32:19 -0700 Subject: [PATCH 185/242] syntax --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 2f07e3cd1..dc2f25578 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -164,15 +164,13 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # FIXME: we need to understand how to use round_num instead of current_epoch # this will matter in straggler handling cases # TODO: Should we put this in a separate process? 
- # TODO: Currently allowing at most 1 second of valiation over one batch in order to avoid NNUnet code throwing exception due - # to empty val results train_completed, \ val_completed, \ this_ave_train_loss, \ this_ave_val_loss, \ this_val_eval_metrics, \ this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, + this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, epochs=epochs, From 6373de5253be3e18c7bff30865839c2bdb3f0177 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Sun, 13 Oct 2024 13:50:28 -0700 Subject: [PATCH 186/242] corrected return to be train and val completed as opposed to batches trained and validated --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 96ef46b2e..6a08edaec 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -294,8 +294,8 @@ def __init__(self, **kwargs): train_completed = batches_applied_train / float(num_train_batches_per_epoch) val_completed = batches_applied_val / float(num_val_batches_per_epoch) - return batches_applied_train, \ - batches_applied_val, \ + return train_completed, \ + val_completed, \ this_ave_train_loss, \ this_ave_val_loss, \ this_val_eval_metrics, \ From a715b15a9923e31de55b9cd633aca979596f35c6 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Sun, 13 Oct 2024 14:08:15 -0700 Subject: [PATCH 187/242] typo --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index dc2f25578..cd4e7f17e 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -266,10 
+266,10 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after metrics = {'val_eval': all_val_eval_metrics[-1], - 'val_eval_C1': this_val_eval_metrics_C1[-1], - 'val_eval_C2': this_val_eval_metrics_C2[-1], - 'val_eval_C3': this_val_eval_metrics_C3[-1], - 'val_eval_C4': this_val_eval_metrics_C4[-1]} + 'val_eval_C1': all_val_eval_metrics_C1[-1], + 'val_eval_C2': all_val_eval_metrics_C2[-1], + 'val_eval_C3': all_val_eval_metrics_C3[-1], + 'val_eval_C4': all_val_eval_metrics_C4[-1]} return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) From a958151dbe87d8f89bfa076bd36c9e668028eb6d Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 14 Oct 2024 10:12:09 -0700 Subject: [PATCH 188/242] small cleanup --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- examples/fl_post/fl/project/src/runner_pt_chkpt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index cd4e7f17e..4a99a47be 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index ac7568c20..fec3a0782 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? 
+Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports From 2dcd79264e850bc10d1ffd4691a6d805716873dc Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 16 Oct 2024 12:43:59 -0700 Subject: [PATCH 189/242] setting train and val cutoff in plan, and setting defaults to infinity --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 2 ++ examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index d6029fa8b..08f6db8f9 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -34,6 +34,8 @@ task_runner : device : cuda gpu_num_string : '0' nnunet_task : Task537_FLPost + train_cutoff : 100 + val_cutoff : 3 network : defaults : plan/defaults/network.yaml diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 4a99a47be..6931757be 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,7 +33,7 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=100, + train_cutoff=np.inf, val_cutoff=np.inf, nnunet_task=None, config_path=None, From a98b2ef576b4831624a17013b274d605f496bbf9 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 17 Oct 2024 15:04:54 -0700 Subject: [PATCH 190/242] post review with Micah --- .../fl/mlcube/workspace/training_config.yaml | 11 ++++---- examples/fl_post/fl/project/be_Dockerfile | 4 +-- examples/fl_post/fl/project/nnunet_setup.py | 2 +- examples/fl_post/fl/project/requirements.txt | 2 +- .../fl/project/src/nnunet_dummy_dataloader.py | 16 +++--------- examples/fl_post/fl/project/src/nnunet_v1.py | 9 
++++--- .../fl_post/fl/project/src/runner_nnunetv1.py | 26 ++++++++----------- 7 files changed, 30 insertions(+), 40 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 952113060..6c95021ca 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -30,11 +30,12 @@ task_runner : defaults : plan/defaults/task_runner.yaml template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner settings : - device : cuda - gpu_num_string : '0' - nnunet_task : Task537_FLPost - train_cutoff : 100 - val_cutoff : 3 + device : cuda + gpu_num_string : '0' + nnunet_task : Task537_FLPost + train_cutoff : 100 + val_cutoff : 3 + actual_max_num_epochs : 1000 network : defaults : plan/defaults/network.yaml diff --git a/examples/fl_post/fl/project/be_Dockerfile b/examples/fl_post/fl/project/be_Dockerfile index b69f527c2..874266e3f 100644 --- a/examples/fl_post/fl/project/be_Dockerfile +++ b/examples/fl_post/fl/project/be_Dockerfile @@ -6,14 +6,14 @@ ENV CUDA_VISIBLE_DEVICES="0" # ENV https_proxy="http://proxy-us.intel.com:912" ENV no_proxy=localhost,spr-gpu01.jf.intel.com -ENV no_proxy________________="http://proxy-us.intel.com:912" +ENV no_proxy_________________________________________________________________________="http://proxy-us.intel.com:912" # install project dependencies RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 -COPY ./requirements.txt /mlcube_project/requirements.txt +COPY ./be_requirements.txt /mlcube_project/requirements.txt RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt # Create similar env with cuda118 diff --git a/examples/fl_post/fl/project/nnunet_setup.py 
b/examples/fl_post/fl/project/nnunet_setup.py index 5ce499996..81f6787d8 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -28,7 +28,7 @@ def main(postopp_pardir, plans_path=None, local_plans_identifier=local_plans_identifier, shared_plans_identifier=shared_plans_identifier, - overwrite_nnunet_datadirs=False, + overwrite_nnunet_datadirs=True, timestamp_selection='all', cuda_device='0', verbose=False): diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt index ef2281ff7..8f03308f9 100644 --- a/examples/fl_post/fl/project/requirements.txt +++ b/examples/fl_post/fl/project/requirements.txt @@ -1,4 +1,4 @@ onnx==1.13.0 typer==0.9.0 -git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@supporting_partial_epochs +git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@main#egg=nnunet numpy==1.26.4 diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py index a757b338c..68cbbbc40 100644 --- a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -12,24 +12,16 @@ import os class NNUNetDummyDataLoader(): - def __init__(self, data_path, p_train, partial_epoch=1.0): + def __init__(self, data_path, p_train): self.task_name = data_path data_base_path = os.path.join(os.environ['nnUNet_preprocessed'], self.task_name) with open(f'{data_base_path}/dataset.json', 'r') as f: data_config = json.load(f) data_size = data_config['numTraining'] - # NOTE: Intended use with PyTorchNNUNetCheckpointTaskRunner where partial_epoch scales down num_train_batches_per_epoch - # and num_val_batches_per_epoch. NNUnet loaders sample batches with replacement. Ignoring rounding (int()), - # the 'data sizes' below are divided by batch_size to obtain the number of batches used per epoch. 
- # These 'data sizes' therefore establish correct relative weights for train and val result aggregation over collaboarators - # due to the fact that batch_size is equal across all collaborators. In addition, over many rounds each data point - # at a particular collaborator informs the results with equal measure. In particular, the average number of times (over - # repeated runs of the federation) that a particular sample is used for a training or val result - # over the corse of the whole federation is given by the 'data sizes' below. # TODO: determine how nnunet validation splits round - self.train_data_size = int(partial_epoch * p_train * data_size) - self.valid_data_size = int(partial_epoch * (1 - p_train) * data_size) + self.train_data_size = int(p_train * data_size) + self.valid_data_size = data_size - self.train_data_size def get_feature_shape(self): return [1,1,1] @@ -41,4 +33,4 @@ def get_valid_data_size(self): return self.valid_data_size def get_task_name(self): - return self.task_name + return self.task_name \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 6a08edaec..cbb1d611e 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,7 +54,7 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(TOTAL_max_num_epochs, +def train_nnunet(actual_max_num_epochs, epochs, current_epoch, val_epoch=True, @@ -83,7 +83,8 @@ def train_nnunet(TOTAL_max_num_epochs, pretrained_weights=None): """ - TOTAL_max_num_epochs (int): Provides the total number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during the course of federated training) + actual_max_num_epochs (int): Provides the number of epochs intended to be trained + (this needs to be held constant outside of individual calls to this function during with max_num_epochs 
is set to one more than the current epoch) epochs (int): Number of epochs to trainon top of current epoch current_epoch (int): Which epoch will be used to grab the model val_epoch (bool) : Will validation be performed @@ -211,7 +212,7 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, - TOTAL_max_num_epochs=TOTAL_max_num_epochs, + actual_max_num_epochs=actual_max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, @@ -259,7 +260,7 @@ def __init__(self, **kwargs): return if find_lr: - trainer.find_lr(num_iters=self.TOTAL_max_num_epochs) + trainer.find_lr(num_iters=self.actual_max_num_epochs) else: if not validation_only: if args.continue_training: diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 6931757be..5d93fe344 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -37,7 +37,7 @@ def __init__(self, val_cutoff=np.inf, nnunet_task=None, config_path=None, - TOTAL_max_num_epochs=1000, + actual_max_num_epochs=1000, **kwargs): """Initialize. @@ -46,7 +46,7 @@ def __init__(self, val_cutoff (int) : Total time (in seconds) allowed for iterating over val batches (plus or minus one iteration since check willl be once an iteration). nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. 
- TOTAL_max_num_epochs (int) : Total number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating + actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). TODO: """ @@ -80,7 +80,7 @@ def __init__(self, self.train_cutoff = train_cutoff self.val_cutoff = val_cutoff self.config_path = config_path - self.TOTAL_max_num_epochs=TOTAL_max_num_epochs + self.actual_max_num_epochs=actual_max_num_epochs # self.task_completed is a dictionary of task to amount completed as a float in [0,1] # Values will be dynamically updated @@ -172,7 +172,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): this_val_eval_metrics_C1, \ this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, epochs=epochs, current_epoch=current_epoch, train_cutoff=self.train_cutoff, @@ -227,7 +227,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_val_eval_metrics_C1, \ this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = train_nnunet(TOTAL_max_num_epochs=self.TOTAL_max_num_epochs, + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, epochs=1, current_epoch=current_epoch, train_cutoff=0, @@ -286,7 +286,7 @@ def load_metrics(self, filepath): """ - def get_train_data_size(self, task_dependent=False, task_name=None): + def get_train_data_size(self, task_name=None): """Get the number of training examples. It will be used for weighted averaging in aggregation. 
@@ -296,16 +296,14 @@ def get_train_data_size(self, task_dependent=False, task_name=None): Returns: int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema """ - if not task_dependent: + if not task_name: return self.data_loader.get_train_data_size() - elif not task_name: - raise ValueError(f"If using task dependent data size, must provide task_name.") else: # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return int(np.ceil(self.task_completed[task_name] * self.data_loader.get_train_data_size())) + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_train_data_size())) - def get_valid_data_size(self, task_dependent=False, task_name=None): + def get_valid_data_size(self, task_name=None): """Get the number of training examples. It will be used for weighted averaging in aggregation. @@ -315,10 +313,8 @@ def get_valid_data_size(self, task_dependent=False, task_name=None): Returns: int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema """ - if not task_dependent: + if not task_name: return self.data_loader.get_valid_data_size() - elif not task_name: - raise ValueError(f"If using task dependent data size, must provide task_name.") else: # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return int(np.ceil(self.task_completed[task_name] * self.data_loader.get_valid_data_size())) + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_valid_data_size())) From f99a88c8279b21d459990207be433c49ad526292 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 18 Oct 2024 15:31:11 -0700 Subject: [PATCH 191/242] inserting old debug stuff as info (stdout) to have going forward --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 5d93fe344..66eabfdc8 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -186,6 +186,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.task_completed['locally_tuned_model_validation'] = val_completed self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + WORKING HEREself.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + # 3. Prepare metrics metrics = {'train_loss': this_ave_train_loss} From 9bafb04d05c772f2790ed35ceab90e373cdd5a03 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 18 Oct 2024 15:46:28 -0700 Subject: [PATCH 192/242] now inserting info to stdout --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 66eabfdc8..54199e7d0 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -186,7 +186,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.task_completed['locally_tuned_model_validation'] = val_completed self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - WORKING HEREself.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") + self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") # 3. Prepare metrics From 2bfea20b7e99fed2c86c7729273b0b94f4f617cd Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 18 Oct 2024 15:53:37 -0700 Subject: [PATCH 193/242] changing to 300 rounds for testing --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 2 +- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 6c95021ca..ba2a1bee7 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,7 +5,7 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : 2 + rounds_to_train : 300 admins_endpoints_mapping: testfladmin@example.com: - GetExperimentStatus diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 54199e7d0..8f671c8a3 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -186,7 +186,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.task_completed['locally_tuned_model_validation'] = val_completed self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") + self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train and {self.get_valid_data_size()} for val.\n") # 3. Prepare metrics From c8002adc8a3b8d91aa93ccf4f831406673f3fbe8 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 18 Oct 2024 16:01:55 -0700 Subject: [PATCH 194/242] was not printing at the time info was properly populated --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 8f671c8a3..c9c5c22af 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -185,14 +185,15 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.task_completed['train'] = train_completed self.task_completed['locally_tuned_model_validation'] = val_completed - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train and {self.get_valid_data_size()} for val.\n") - - # 3. Prepare metrics metrics = {'train_loss': this_ave_train_loss} - return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) + global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. 
Exact rates are: {train_completed} and {val_completed}") + self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train and {self.get_valid_data_size()} for val.\n") + + return global_tensor_dict, local_tensor_dict def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs): From 1482e9f1c80202aeebbaf5499b199c16cbee3d4d Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 18 Oct 2024 17:28:39 -0700 Subject: [PATCH 195/242] now using round and not saving any checkpoints --- examples/fl_post/fl/project/src/nnunet_v1.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index cbb1d611e..c5d2cb282 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -55,8 +55,7 @@ def seed_everything(seed=1234): def train_nnunet(actual_max_num_epochs, - epochs, - current_epoch, + round, val_epoch=True, train_epoch=True, train_cutoff=np.inf, @@ -85,8 +84,7 @@ def train_nnunet(actual_max_num_epochs, """ actual_max_num_epochs (int): Provides the number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during with max_num_epochs is set to one more than the current epoch) - epochs (int): Number of epochs to trainon top of current epoch - current_epoch (int): Which epoch will be used to grab the model + round (int): Federated round, equal to the epoch used for the model (lr scheduling) val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) makes lr step and epoch increment train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. @@ -131,6 +129,8 @@ def train_nnunet(actual_max_num_epochs, "Optional. Beta. Use with caution." 
disable_next_stage_pred: If True, do not predict next stage """ + # hard coded, as internal trainer.run_training is currently written to run only one epoch + epochs = 1 class Arguments(): def __init__(self, **kwargs): @@ -237,8 +237,8 @@ def __init__(self, **kwargs): True # if false it will not store/overwrite _latest but separate files each ) - trainer.max_num_epochs = current_epoch + epochs - trainer.epoch = current_epoch + trainer.max_num_epochs = round + epochs + trainer.epoch = round trainer.initialize(not validation_only) From 227cac21cdfa43fd1e09365042fc70def2f61e88 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 12:36:36 -0700 Subject: [PATCH 196/242] I want to be explicit about where we are writing checkpoints now --- examples/fl_post/fl/project/src/runner_pt_chkpt.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index fec3a0782..a9427d88e 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -82,12 +82,10 @@ def __init__(self, self.replace_checkpoint(self.checkpoint_path_initial) - def load_checkpoint(self, checkpoint_path=None, map_location=None): + def load_checkpoint(self, map_location=None, checkpoint_path): """ Function used to load checkpoint from disk. 
""" - if not checkpoint_path: - checkpoint_path = self.checkpoint_path_load checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) return checkpoint_dict @@ -124,7 +122,7 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): return self.required_tensorkeys_for_function[func_name] def reset_opt_vars(self): - current_checkpoint_dict = self.load_checkpoint() + current_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, @@ -172,7 +170,7 @@ def read_tensors_from_checkpoint(self, with_opt_vars): dict: Tensor dictionary {**dict, **optimizer_dict} """ - checkpoint_dict = self.load_checkpoint() + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) state = to_cpu_numpy(checkpoint_dict['state_dict']) if with_opt_vars: opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) From b90f732e2ab33d44cbd813686e1ea640b2b3dc59 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 12:38:47 -0700 Subject: [PATCH 197/242] round instead of current_epoch --- .../fl_post/fl/project/src/runner_nnunetv1.py | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index c9c5c22af..5b68aec5b 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -155,15 +155,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. 
Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - self.logger.info(f"In col train method, loaded checkpoint with current epoch: {current_epoch}") - # 2. Train/val function existing externally - # Some todo inside function below - # TODO: test for off-by-one error - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Training for round:{round_num}") train_completed, \ val_completed, \ this_ave_train_loss, \ @@ -174,7 +167,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, epochs=epochs, - current_epoch=current_epoch, + round=round, train_cutoff=self.train_cutoff, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), @@ -213,15 +206,10 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): if not from_checkpoint: self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - self.logger.info(f"In col val method, loaded checkpoint with current epoch: {current_epoch}") + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Validating for round:{round_num}") # 2. Train/val function existing externally # Some todo inside function below - # TODO: test for off-by-one error - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? 
train_completed, \ val_completed, \ this_ave_train_loss, \ @@ -232,7 +220,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, epochs=1, - current_epoch=current_epoch, + round=round_num, train_cutoff=0, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), From 5d1c7273d388a626e79fd25fa27d125822216ce7 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 12:59:24 -0700 Subject: [PATCH 198/242] some clean up --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 5b68aec5b..b2387e55b 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -153,6 +153,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" + # epochs is not used, inside function is hard coded for epochs=1 + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. 
Insert tensor_dict info into checkpoint self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) @@ -166,7 +168,6 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - epochs=epochs, round=round, train_cutoff=self.train_cutoff, val_cutoff = self.val_cutoff, @@ -219,7 +220,6 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - epochs=1, round=round_num, train_cutoff=0, val_cutoff = self.val_cutoff, @@ -244,7 +244,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): 'val_eval_C4': this_val_eval_metrics_C4} else: checkpoint_dict = self.load_checkpoint() - # double check uncomment below for testing + # double check: uncomment below for testing # compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") all_tr_losses, \ From e774fe9ee53d2d1e912e94fd315a15a886bf82bd Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 13:53:04 -0700 Subject: [PATCH 199/242] default argument at end --- examples/fl_post/fl/project/src/runner_pt_chkpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index a9427d88e..a7fbd2056 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -82,7 +82,7 @@ def __init__(self, self.replace_checkpoint(self.checkpoint_path_initial) - def load_checkpoint(self, map_location=None, checkpoint_path): + def load_checkpoint(self, checkpoint_path, map_location=None): """ Function used to load checkpoint from disk. 
""" From 7314eef8328ccbe2bb8615687fab1041d5e3c37f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 14:06:30 -0700 Subject: [PATCH 200/242] provide checkpoint path --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index b2387e55b..7a9c9039a 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -124,7 +124,7 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): # get device for correct placement of tensors device = self.device - checkpoint_dict = self.load_checkpoint(map_location=device) + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict From 659d024304fb820e2cbb6cabb4b694893bde1b92 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 15:23:28 -0700 Subject: [PATCH 201/242] Now using fl_round rather than round to avoid collision with internal function --- examples/fl_post/fl/project/src/nnunet_v1.py | 8 ++++---- examples/fl_post/fl/project/src/runner_nnunetv1.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index c5d2cb282..380f12d25 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -55,7 +55,7 @@ def seed_everything(seed=1234): def train_nnunet(actual_max_num_epochs, - round, + fl_round, val_epoch=True, train_epoch=True, train_cutoff=np.inf, @@ -84,7 +84,7 @@ def train_nnunet(actual_max_num_epochs, """ actual_max_num_epochs (int): Provides the number of epochs intended to be trained (this needs to be held
constant outside of individual calls to this function during with max_num_epochs is set to one more than the current epoch) - round (int): Federated round, equal to the epoch used for the model (lr scheduling) + fl_round (int): Federated round, equal to the epoch used for the model (lr scheduling) val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) makes lr step and epoch increment train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. @@ -237,8 +237,8 @@ def __init__(self, **kwargs): True # if false it will not store/overwrite _latest but separate files each ) - trainer.max_num_epochs = round + epochs - trainer.epoch = round + trainer.max_num_epochs = fl_round + epochs + trainer.epoch = fl_round trainer.initialize(not validation_only) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 7a9c9039a..ce6ef7c79 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -168,7 +168,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - round=round, + fl_round=round_num, train_cutoff=self.train_cutoff, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), @@ -220,7 +220,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_val_eval_metrics_C2, \ this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - round=round_num, + fl_round=round_num, train_cutoff=0, val_cutoff = self.val_cutoff, task=self.data_loader.get_task_name(), From ccc5fcad3a7c2fc9161613fe69747cbfdfb21450 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 
Oct 2024 16:27:40 -0700 Subject: [PATCH 202/242] missed another spot where checkpoint path was needed --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index ce6ef7c79..f71dfbc4e 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -243,7 +243,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): 'val_eval_C3': this_val_eval_metrics_C3, 'val_eval_C4': this_val_eval_metrics_C4} else: - checkpoint_dict = self.load_checkpoint() + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) # double check: uncomment below for testing # compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") From 11617546315f3d5605e684286ad1d32c1719a236 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 21 Oct 2024 19:00:53 -0700 Subject: [PATCH 203/242] setting of epoch pre nnunet train/val corresponds to previous epoch trained --- examples/fl_post/fl/project/src/nnunet_v1.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 380f12d25..965c353c5 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -238,7 +238,8 @@ def __init__(self, **kwargs): ) trainer.max_num_epochs = fl_round + epochs - trainer.epoch = fl_round + # previous epoch trained + trainer.epoch = fl_round - 1 trainer.initialize(not validation_only) From 4f998fd276dbc56586320c7595bce51a5cef1d5f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 09:47:25 -0700 Subject: [PATCH 204/242] inserting some debug prints --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ++++ 1 file 
changed, 4 insertions(+) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 965c353c5..6ec4dbc65 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -243,6 +243,8 @@ def __init__(self, **kwargs): trainer.initialize(not validation_only) + print(f"Brandon DEBUG - after trainer initialization trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data # point is expected to be seen epochs number of times @@ -274,6 +276,8 @@ def __init__(self, **kwargs): # new training without pretraine weights, do nothing pass + print(f"Brandon DEBUG - right before train/val call trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") + batches_applied_train, \ batches_applied_val, \ this_ave_train_loss, \ From d6e7f4db100499ed69c62c5ce055cfab0556f33e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 10:26:38 -0700 Subject: [PATCH 205/242] indentation issues --- examples/fl_post/fl/project/src/nnunet_v1.py | 118 ++++++++++--------- 1 file changed, 60 insertions(+), 58 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 6ec4dbc65..e6e12cc13 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -221,40 +221,10 @@ def __init__(self, **kwargs): deterministic=deterministic, fp16=run_mixed_precision, ) - # we want latest checkoint only (not best or any intermediate) - trainer.save_final_checkpoint = ( - True # whether or not to save the final checkpoint - ) - trainer.save_best_checkpoint = ( - False # whether or not to save the best checkpoint according to - ) - # self.best_val_eval_criterion_MA - trainer.save_intermediate_checkpoints = ( - False # whether or 
not to save checkpoint_latest. We need that in case - ) - # the training chashes - trainer.save_latest_only = ( - True # if false it will not store/overwrite _latest but separate files each - ) - - trainer.max_num_epochs = fl_round + epochs - # previous epoch trained - trainer.epoch = fl_round - 1 + trainer.initialize(not validation_only) - print(f"Brandon DEBUG - after trainer initialization trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") - - # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data - # point is expected to be seen epochs number of times - - num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) - - # the nnunet trainer attributes have a different naming convention than I am using - trainer.num_batches_per_epoch = num_train_batches_per_epoch - trainer.num_val_batches_per_epoch = num_val_batches_per_epoch - if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": trainer.save_checkpoint( join(trainer.output_folder, "model_final_checkpoint.model") @@ -275,21 +245,6 @@ def __init__(self, **kwargs): else: # new training without pretraine weights, do nothing pass - - print(f"Brandon DEBUG - right before train/val call trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") - - batches_applied_train, \ - batches_applied_val, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, - val_cutoff=val_cutoff, - val_epoch=val_epoch, - train_epoch=train_epoch) else: # if valbest: # trainer.load_best_checkpoint(train=False) @@ -297,17 +252,64 @@ def __init__(self, **kwargs): # trainer.load_final_checkpoint(train=False) 
trainer.load_latest_checkpoint() - train_completed = batches_applied_train / float(num_train_batches_per_epoch) - val_completed = batches_applied_val / float(num_val_batches_per_epoch) - - return train_completed, \ - val_completed, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 + print(f"Brandon DEBUG - after trainer initialization an model load, trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") + + # we want latest checkoint only (not best or any intermediate) + trainer.save_final_checkpoint = ( + True # whether or not to save the final checkpoint + ) + trainer.save_best_checkpoint = ( + False # whether or not to save the best checkpoint according to + ) + # self.best_val_eval_criterion_MA + trainer.save_intermediate_checkpoints = ( + False # whether or not to save checkpoint_latest. We need that in case + ) + # the training chashes + trainer.save_latest_only = ( + True # if false it will not store/overwrite _latest but separate files each + ) + + trainer.max_num_epochs = fl_round + epochs + # previous epoch trained + trainer.epoch = fl_round + + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data + # point is expected to be seen epochs number of times + + num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + + # the nnunet trainer attributes have a different naming convention than I am using + trainer.num_batches_per_epoch = num_train_batches_per_epoch + trainer.num_val_batches_per_epoch = num_val_batches_per_epoch + + print(f"Brandon DEBUG - right before train/val call trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") + + batches_applied_train, \ + batches_applied_val, \ + 
this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, + val_cutoff=val_cutoff, + val_epoch=val_epoch, + train_epoch=train_epoch) + + train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) + + return train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 From ee457671e47de32d6a676226e81beec10b04cc9e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 12:02:52 -0700 Subject: [PATCH 206/242] removing some debug prints --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index e6e12cc13..938dda48f 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -252,8 +252,6 @@ def __init__(self, **kwargs): # trainer.load_final_checkpoint(train=False) trainer.load_latest_checkpoint() - print(f"Brandon DEBUG - after trainer initialization an model load, trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") - # we want latest checkoint only (not best or any intermediate) trainer.save_final_checkpoint = ( True # whether or not to save the final checkpoint @@ -284,8 +282,6 @@ def __init__(self, **kwargs): trainer.num_batches_per_epoch = num_train_batches_per_epoch trainer.num_val_batches_per_epoch = num_val_batches_per_epoch - print(f"Brandon DEBUG - right before train/val call trainer.epoch:{trainer.epoch}, trainer.max_num_epochs:{trainer.max_num_epochs}") - batches_applied_train, \ 
batches_applied_val, \ this_ave_train_loss, \ From 43a08eff7accf77df47df08592173acc0f83bdbf Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 12:55:23 -0700 Subject: [PATCH 207/242] removing a print --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index f71dfbc4e..f2beff0cb 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -185,8 +185,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - self.logger.info(f"Data size right now returns {self.get_train_data_size()} for train and {self.get_valid_data_size()} for val.\n") - + return global_tensor_dict, local_tensor_dict From fb69b716ee44129124a9b33e4f94c0f57a6dc324 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 13:27:37 -0700 Subject: [PATCH 208/242] changes due to new signature for nnunet train function --- examples/fl_post/fl/project/nnunet_model_setup.py | 11 ++++++----- examples/fl_post/fl/project/src/nnunet_v1.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 5f31bfec6..2712ccc08 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -7,12 +7,13 @@ def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): 
os.environ['CUDA_VISIBLE_DEVICES']=cuda_device - print(f"###########\nStarting training for task: {task}\n") - train_nnunet(epochs=1, - current_epoch = current_epoch, - network = network, + print(f"###########\nStarting training a single epoch for task: {task}\n") + # Function below is now hard coded for a single epoch of training. + train_nnunet(actual_max_num_epochs=1000, + fl_round=current_epoch, + network=network, task=task, - network_trainer = network_trainer, + network_trainer=network_trainer, fold=fold, continue_training=continue_training, p=plans_identifier) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 938dda48f..e539a8913 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -269,7 +269,6 @@ def __init__(self, **kwargs): ) trainer.max_num_epochs = fl_round + epochs - # previous epoch trained trainer.epoch = fl_round # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data From 131dfd10859c88ea343001c606a1593db6ac95de Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 13:36:03 -0700 Subject: [PATCH 209/242] preparing to test on brain threshold data --- examples/fl_post/fl/be_setup_test_no_docker.sh | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh index 19713ae78..cdaf15ccf 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -199,14 +199,16 @@ mkdir mlcube_col5/workspace/data # this is the one I had success running on #DATA_DIRS=test_data_small_from_hasan -SIZE="hundred" -#SIZE="thousand" - -DATA_DIR_1="test_${SIZE}_BraTS20_3square_0" -DATA_DIR_2="test_${SIZE}_BraTS20_3square_1" -DATA_DIR_3="test_${SIZE}_BraTS20_3square_2" 
-DATA_DIR_4="test_${SIZE}_BraTS20_3square_3" -DATA_DIR_5="test_${SIZE}_BraTS20_3square_4" +#SIZE="hundred" +#SUPPLEMENT="square" +SUPPLEMENT="thresholdbrainsorted" +SIZE="thousand" + +DATA_DIR_1="test_${SIZE}_BraTS20_3${SUPPLEMENT}_0" +DATA_DIR_2="test_${SIZE}_BraTS20_3${SUPPLEMENT}_1" +DATA_DIR_3="test_${SIZE}_BraTS20_3${SUPPLEMENT}_2" +DATA_DIR_4="test_${SIZE}_BraTS20_3${SUPPLEMENT}_3" +DATA_DIR_5="test_${SIZE}_BraTS20_3${SUPPLEMENT}_4" cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/labels/* mlcube_col1/workspace/labels cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/data/* mlcube_col1/workspace/data From fc60385df4bc1360baf8b34c0dc731c31a26290e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 22 Oct 2024 14:30:22 -0700 Subject: [PATCH 210/242] changes to training config for testing --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index ba2a1bee7..8d9ad4967 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -33,9 +33,9 @@ task_runner : device : cuda gpu_num_string : '0' nnunet_task : Task537_FLPost - train_cutoff : 100 - val_cutoff : 3 - actual_max_num_epochs : 1000 + train_cutoff : np.inf + val_cutoff : np.inf + actual_max_num_epochs : 300 network : defaults : plan/defaults/network.yaml From 12303825bdf258e525939504b746b98d9a8bf261 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 23 Oct 2024 11:58:31 -0700 Subject: [PATCH 211/242] another iteration of data creation (so changing data file names) --- examples/fl_post/fl/be_setup_test_no_docker.sh | 3 ++- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh 
b/examples/fl_post/fl/be_setup_test_no_docker.sh index cdaf15ccf..75f5265c0 100644 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ b/examples/fl_post/fl/be_setup_test_no_docker.sh @@ -201,7 +201,8 @@ mkdir mlcube_col5/workspace/data #SIZE="hundred" #SUPPLEMENT="square" -SUPPLEMENT="thresholdbrainsorted" +#SUPPLEMENT="thresholdbrainsorted" +SUPPLEMENT="thresholdbrainandsquaresorted" SIZE="thousand" DATA_DIR_1="test_${SIZE}_BraTS20_3${SUPPLEMENT}_0" diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 8d9ad4967..bea86ab4e 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -33,8 +33,8 @@ task_runner : device : cuda gpu_num_string : '0' nnunet_task : Task537_FLPost - train_cutoff : np.inf - val_cutoff : np.inf + train_cutoff : 6000 + val_cutoff : 6000 actual_max_num_epochs : 300 network : @@ -90,4 +90,4 @@ straggler_handling_policy : template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling settings : straggler_cutoff_time : 600 - minimum_reporting : 2 \ No newline at end of file + minimum_reporting : 5 \ No newline at end of file From b5b6fa4ce62024860a12cd95f35687653bc3e5c0 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Wed, 23 Oct 2024 22:15:23 -0700 Subject: [PATCH 212/242] changes new config and parameter definitions --- .../fl/mlcube/workspace/training_config.yaml | 23 +++++++++++++++++-- .../fl_post/fl/project/src/runner_nnunetv1.py | 16 ++++--------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index bea86ab4e..cb1bbf476 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -10,6 +10,26 @@ aggregator : testfladmin@example.com: 
- GetExperimentStatus - SetStragglerCuttoffTime + - SetDynamicTaskArg + - GetDynamicTaskArg + dynamictaskargs: &dynamictaskargs + train: + train_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + aggregated_model_validation: + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day collaborator : @@ -18,6 +38,7 @@ collaborator : settings : delta_updates : false opt_treatment : CONTINUE_LOCAL + dynamictaskargs: *dynamictaskargs data_loader : defaults : plan/defaults/data_loader.yaml @@ -33,8 +54,6 @@ task_runner : device : cuda gpu_num_string : '0' nnunet_task : Task537_FLPost - train_cutoff : 6000 - val_cutoff : 6000 actual_max_num_epochs : 300 network : diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index f2beff0cb..db84b0acd 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -33,8 +33,6 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): pull model state from a PyTorch checkpoint.""" def __init__(self, - train_cutoff=np.inf, - val_cutoff=np.inf, nnunet_task=None, config_path=None, actual_max_num_epochs=1000, @@ -42,8 +40,6 @@ def __init__(self, """Initialize. Args: - train_cutoff (int) : Total time (in seconds) allowed for iterating over train batches (plus or minus one iteration since check willl be once an iteration). - val_cutoff (int) : Total time (in seconds) allowed for iterating over val batches (plus or minus one iteration since check willl be once an iteration). nnunet_task (str) : Task string used to identify the data and model folders config_path(str) : Path to the configuration file used by the training and validation script. 
actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating @@ -77,8 +73,6 @@ def __init__(self, **kwargs, ) - self.train_cutoff = train_cutoff - self.val_cutoff = val_cutoff self.config_path = config_path self.actual_max_num_epochs=actual_max_num_epochs @@ -149,7 +143,7 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, train_cutoff_time, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" @@ -169,8 +163,8 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): this_val_eval_metrics_C3, \ this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, fl_round=round_num, - train_cutoff=self.train_cutoff, - val_cutoff = self.val_cutoff, + train_cutoff=train_cutoff_time, + val_cutoff = val_cutoff_time, task=self.data_loader.get_task_name(), val_epoch=True, train_epoch=True) @@ -189,7 +183,7 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): return global_tensor_dict, local_tensor_dict - def validate(self, col_name, round_num, input_tensor_dict, from_checkpoint=False, **kwargs): + def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.inf, from_checkpoint=False, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" @@ -221,7 +215,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, fl_round=round_num, train_cutoff=0, - val_cutoff = self.val_cutoff, + val_cutoff = val_cutoff_time, 
task=self.data_loader.get_task_name(), val_epoch=True, train_epoch=False) From 4e2aae1e4c4ae1a9f073aeb6aa28252266a9bc75 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 25 Oct 2024 17:53:59 -0700 Subject: [PATCH 213/242] enabling dampening of the train_completion with admin control --- .../fl_post/fl/mlcube/workspace/training_config.yaml | 7 +++++++ examples/fl_post/fl/project/src/runner_nnunetv1.py | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index cb1bbf476..bd9daf629 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -24,12 +24,19 @@ aggregator : min: 10 # 10 seconds max: 86400 # one day value: 86400 # one day + train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) + admin_settable: True + min: 1e-2 # shifts non 0.0 completion rates much closer to 1.0 + max: 1.0 # leaves completion rates as is + value: 1.0 + aggregated_model_validation: val_cutoff_time: admin_settable: True min: 10 # 10 seconds max: 86400 # one day value: 86400 # one day + weights_alpha: *weights_alpha collaborator : diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index db84b0acd..96cb257cc 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -143,7 +143,7 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, train_cutoff_time, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, train_cutoff_time, train_completion_dampener, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner 
"""Perform training for a specified number of epochs.""" @@ -169,6 +169,14 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, val_epoch=True, train_epoch=True) + # dampen the train_completion + """ + values in range: (0, 1] with values near 0.0 making all train_completion rates shift nearer to 1.0, thus making the + trained model update weighting during aggregation stay closer to the plain data size weighting + specifically, update_weight = train_data_size / train_completed**train_completion_dampener + """ + train_completed = train_completed**train_completion_dampener + # update amount of task completed self.task_completed['train'] = train_completed self.task_completed['locally_tuned_model_validation'] = val_completed From 091930821296fa37550816850833270c40eeb302 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Fri, 25 Oct 2024 18:16:01 -0700 Subject: [PATCH 214/242] left a stray idea in the training config that does not belong --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index bd9daf629..2078993b5 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -36,7 +36,6 @@ aggregator : min: 10 # 10 seconds max: 86400 # one day value: 86400 # one day - weights_alpha: *weights_alpha collaborator : From 243f024523db1b6b85f86fb1d9c4f6b89eedd5cb Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Sun, 27 Oct 2024 13:47:09 -0700 Subject: [PATCH 215/242] setting min train_completion_dampener to -1 to allow multiplication instead of division by amount of training completed when modifying data size for update weighting, also allowing col1 to be admin for testing admin control of dampener as well as train and val cutoff time for train (training and local model validation) --- 
examples/fl_post/fl/mlcube/workspace/training_config.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 2078993b5..6456c3a02 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -12,6 +12,9 @@ aggregator : - SetStragglerCuttoffTime - SetDynamicTaskArg - GetDynamicTaskArg + col1@example.com: + - SetDynamicTaskArg + dynamictaskargs: &dynamictaskargs train: train_cutoff_time: @@ -24,9 +27,9 @@ aggregator : min: 10 # 10 seconds max: 86400 # one day value: 86400 # one day - train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) + train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: Value close to zero zero shifts non 0.0 completion rates much closer to 1.0 admin_settable: True - min: 1e-2 # shifts non 0.0 completion rates much closer to 1.0 + min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) max: 1.0 # leaves completion rates as is value: 1.0 From 5bff371d975c98ff3c79e631ffa39696c549805f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 10:39:56 -0700 Subject: [PATCH 216/242] removing be scripts from branch --- examples/fl_post/fl/be_setup_clean.sh | 15 - .../fl_post/fl/be_setup_test_no_docker.sh | 258 ------------------ examples/fl_post/fl/be_test.sh | 57 ---- examples/fl_post/fl/intel_build.sh | 22 -- 4 files changed, 352 deletions(-) delete mode 100644 examples/fl_post/fl/be_setup_clean.sh delete mode 100644 examples/fl_post/fl/be_setup_test_no_docker.sh delete mode 100755 examples/fl_post/fl/be_test.sh delete mode 100755 examples/fl_post/fl/intel_build.sh diff --git a/examples/fl_post/fl/be_setup_clean.sh 
b/examples/fl_post/fl/be_setup_clean.sh deleted file mode 100644 index 5db13b158..000000000 --- a/examples/fl_post/fl/be_setup_clean.sh +++ /dev/null @@ -1,15 +0,0 @@ - -HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" - -cd $HOMEDIR - - - -rm -rf ./mlcube_agg -rm -rf ./mlcube_col1 -rm -rf ./mlcube_col2 -rm -rf ./mlcube_col3 -rm -rf ./mlcube_col4 -rm -rf ./mlcube_col5 -rm -rf ./ca -rm -rf ./for_admin diff --git a/examples/fl_post/fl/be_setup_test_no_docker.sh b/examples/fl_post/fl/be_setup_test_no_docker.sh deleted file mode 100644 index 75f5265c0..000000000 --- a/examples/fl_post/fl/be_setup_test_no_docker.sh +++ /dev/null @@ -1,258 +0,0 @@ -while getopts t flag; do - case "${flag}" in - t) TWO_COL_SAME_CERT="true" ;; - esac -done -TWO_COL_SAME_CERT="${TWO_COL_SAME_CERT:-false}" - -COL1_CN="col1@example.com" -COL2_CN="col2@example.com" -COL3_CN="col3@example.com" -COL4_CN="col4@example.com" -COL5_CN="col5@example.com" - -COL1_LABEL="col1@example.com" -COL2_LABEL="col2@example.com" -COL3_LABEL="col3@example.com" -COL4_LABEL="col4@example.com" -COL5_LABEL="col5@example.com" - -if ${TWO_COL_SAME_CERT}; then - COL1_CN="org1@example.com" - COL2_CN="org2@example.com" - COL3_CN="org3@example.com" - COL4_CN="org4@example.com" - COL5_CN="org5@example.com" # in this case this var is not used actually. 
-fi - - -CODE_CHANGE_DIR="/home/edwardsb/repositories/hasan_medperf/examples/fl_post/fl" - -HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" - -cp -r $CODE_CHANGE_DIR/* $HOMEDIR - -cd $HOMEDIR - -mkdir mlcube_agg -mkdir mlcube_col1 -mkdir mlcube_col2 -mkdir mlcube_col3 -mkdir mlcube_col4 -mkdir mlcube_col5 - - - -cp -r ./mlcube/* ./mlcube -cp -r ./mlcube/* ./mlcube_agg -cp -r ./mlcube/* ./mlcube_col1 -cp -r ./mlcube/* ./mlcube_col2 -cp -r ./mlcube/* ./mlcube_col3 -cp -r ./mlcube/* ./mlcube_col4 -cp -r ./mlcube/* ./mlcube_col5 - -mkdir ./mlcube_agg/workspace/node_cert ./mlcube_agg/workspace/ca_cert -mkdir ./mlcube_col1/workspace/node_cert ./mlcube_col1/workspace/ca_cert -mkdir ./mlcube_col2/workspace/node_cert ./mlcube_col2/workspace/ca_cert -mkdir ./mlcube_col3/workspace/node_cert ./mlcube_col3/workspace/ca_cert -mkdir ./mlcube_col4/workspace/node_cert ./mlcube_col4/workspace/ca_cert -mkdir ./mlcube_col5/workspace/node_cert ./mlcube_col5/workspace/ca_cert -mkdir ca - -HOSTNAME_=$(hostname -A | cut -d " " -f 1) -# HOSTNAME_=$(hostname -I | cut -d " " -f 1) - - -# root ca -openssl genpkey -algorithm RSA -out ca/root.key -pkeyopt rsa_keygen_bits:3072 -openssl req -x509 -new -nodes -key ca/root.key -sha384 -days 36500 -out ca/root.crt \ - -subj "/DC=org/DC=simple/CN=Simple Root CA/O=Simple Inc/OU=Simple Root CA" - -# col1 -sed -i "/^commonName = /c\commonName = $COL1_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL1_CN" csr.conf -cd mlcube_col1/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd $HOMEDIR - -# col2 -sed -i "/^commonName = /c\commonName = $COL2_CN" csr.conf 
-sed -i "/^DNS\.1 = /c\DNS.1 = $COL2_CN" csr.conf -cd mlcube_col2/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd $HOMEDIR - -# col3 -if ${TWO_COL_SAME_CERT}; then - never goes here cp mlcube_col2/workspace/node_cert/* mlcube_col3/workspace/node_cert - cp mlcube_col2/workspace/ca_cert/* mlcube_col3/workspace/ca_cert -else - sed -i "/^commonName = /c\commonName = $COL3_CN" csr.conf - sed -i "/^DNS\.1 = /c\DNS.1 = $COL3_CN" csr.conf - cd mlcube_col3/workspace/node_cert - openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 - openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client - openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf - rm csr.csr - cp ../../../ca/root.crt ../ca_cert/ - cd $HOMEDIR -fi - -# col4 -sed -i "/^commonName = /c\commonName = $COL4_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL4_CN" csr.conf -cd mlcube_col4/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd $HOMEDIR - -# col5 -sed -i "/^commonName = /c\commonName = $COL5_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $COL5_CN" csr.conf -cd 
mlcube_col5/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../../csr.conf -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd $HOMEDIR - - - -# agg -sed -i "/^commonName = /c\commonName = $HOSTNAME_" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $HOSTNAME_" csr.conf -cd mlcube_agg/workspace/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../../csr.conf -extensions v3_server -openssl x509 -req -in csr.csr -CA ../../../ca/root.crt -CAkey ../../../ca/root.key \ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_server_crt -extfile ../../../csr.conf -rm csr.csr -cp ../../../ca/root.crt ../ca_cert/ -cd $HOMEDIR - -# aggregator_config -echo "address: $HOSTNAME_" >> mlcube_agg/workspace/aggregator_config.yaml -echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml - -# cols file -echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL4_LABEL: $COL4_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL5_LABEL: $COL5_CN" >>mlcube_agg/workspace/cols.yaml - -# for admin -ADMIN_CN="admin@example.com" - -mkdir ./for_admin -mkdir ./for_admin/node_cert - -sed -i "/^commonName = /c\commonName = $ADMIN_CN" csr.conf -sed -i "/^DNS\.1 = /c\DNS.1 = $ADMIN_CN" csr.conf -cd for_admin/node_cert -openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:3072 -openssl req -new -key key.key -out csr.csr -config ../../csr.conf -extensions v3_client -openssl x509 -req -in csr.csr -CA ../../ca/root.crt -CAkey ../../ca/root.key 
\ - -CAcreateserial -out crt.crt -days 36500 -sha384 -extensions v3_client_crt -extfile ../../csr.conf -rm csr.csr -mkdir ../ca_cert -cp -r ../../ca/root.crt ../ca_cert/root.crt -cd $HOMEDIR - -# THIS IS BRANDON'S CODE COPYING IN THE SAME DATA -mkdir mlcube_col1/workspace/labels -mkdir mlcube_col1/workspace/data - -mkdir mlcube_col2/workspace/labels -mkdir mlcube_col2/workspace/data - -mkdir mlcube_col3/workspace/labels -mkdir mlcube_col3/workspace/data - -mkdir mlcube_col4/workspace/labels -mkdir mlcube_col4/workspace/data - -mkdir mlcube_col5/workspace/labels -mkdir mlcube_col5/workspace/data - -# DATA_DIR="test_data_links_testforhasan" -# DATA_DIR="test_data_links_random_times_0" -# DATA_DIR="test_data_links" - -# this is the one I had success running on -#DATA_DIRS=test_data_small_from_hasan - -#SIZE="hundred" -#SUPPLEMENT="square" -#SUPPLEMENT="thresholdbrainsorted" -SUPPLEMENT="thresholdbrainandsquaresorted" -SIZE="thousand" - -DATA_DIR_1="test_${SIZE}_BraTS20_3${SUPPLEMENT}_0" -DATA_DIR_2="test_${SIZE}_BraTS20_3${SUPPLEMENT}_1" -DATA_DIR_3="test_${SIZE}_BraTS20_3${SUPPLEMENT}_2" -DATA_DIR_4="test_${SIZE}_BraTS20_3${SUPPLEMENT}_3" -DATA_DIR_5="test_${SIZE}_BraTS20_3${SUPPLEMENT}_4" - -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/labels/* mlcube_col1/workspace/labels -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_1/data/* mlcube_col1/workspace/data - -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_2/labels/* mlcube_col2/workspace/labels -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_2/data/* mlcube_col2/workspace/data - -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_3/labels/* mlcube_col3/workspace/labels -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_3/data/* mlcube_col3/workspace/data - -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_4/labels/* mlcube_col4/workspace/labels -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_4/data/* mlcube_col4/workspace/data - -cp -r /raid/edwardsb/projects/RANO/$DATA_DIR_5/labels/* mlcube_col5/workspace/labels -cp -r 
/raid/edwardsb/projects/RANO/$DATA_DIR_5/data/* mlcube_col5/workspace/data - -# wget https://storage.googleapis.com/medperf-storage/fltest29July/flpost_add29july.tar.gz I copied on spr01 into /home/edwardsb/repo_extras/hasan_medperperf_extras - -# aggregator additional files -mkdir mlcube_agg/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_weights mlcube_agg/workspace/additional_files -# maybe I don't need the one immediately below (only for collaborators) -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_agg/workspace/additional_files - -# col1 additional files -mkdir mlcube_col1/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col1/workspace/additional_files - -# col2 additional files -mkdir mlcube_col2/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col2/workspace/additional_files - -# col3 additional files -mkdir mlcube_col3/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col3/workspace/additional_files - -# col4 additional files -mkdir mlcube_col4/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col4/workspace/additional_files - -# col5 additional files -mkdir mlcube_col5/workspace/additional_files -cp -r /home/edwardsb/repo_extras/hasan_medperf_extras/download_from_hasan/init_nnunet mlcube_col5/workspace/additional_files - - -# source /home/edwardsb/virtual/hasan_medperf/bin/activate diff --git a/examples/fl_post/fl/be_test.sh b/examples/fl_post/fl/be_test.sh deleted file mode 100755 index b4910d2d0..000000000 --- a/examples/fl_post/fl/be_test.sh +++ /dev/null @@ -1,57 +0,0 @@ -export HTTPS_PROXY= -export http_proxy= - 
-HOMEDIR="/raid/edwardsb/projects/RANO/hasan_medperf/examples/fl_post/fl" - -cd $HOMEDIR - -# generate plan and copy it to each node -GENERATE_PLAN_PLATFORM="docker" -AGG_PLATFORM="docker" -COL1_PLATFORM="docker" -COL2_PLATFORM="docker" -COL3_PLATFORM="docker" -COL4_PLATFORM="docker" -COL5_PLATFORM="docker" - -medperf --platform $GENERATE_PLAN_PLATFORM mlcube run --mlcube ./mlcube_agg --task generate_plan -mv ./mlcube_agg/workspace/plan/plan.yaml ./mlcube_agg/workspace -rm -r ./mlcube_agg/workspace/plan -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col1/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col2/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col3/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col4/workspace -cp ./mlcube_agg/workspace/plan.yaml ./mlcube_col5/workspace -cp ./mlcube_agg/workspace/plan.yaml ./for_admin - -# Run nodes -AGG="medperf --platform $AGG_PLATFORM mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf --platform $COL2_PLATFORM --gpus=device=2 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf --platform $COL3_PLATFORM --gpus=device=3 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" -COL4="medperf --platform $COL4_PLATFORM --gpus=device=4 mlcube run --mlcube ./mlcube_col4 --task train -e MEDPERF_PARTICIPANT_LABEL=col4@example.com" -COL5="medperf --platform $COL5_PLATFORM --gpus=device=5 mlcube run --mlcube ./mlcube_col5 --task train -e MEDPERF_PARTICIPANT_LABEL=col5@example.com" - -# medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & - -# gnome-terminal -- bash -c "$AGG; bash" -# gnome-terminal -- bash -c "$COL1; bash" -# gnome-terminal -- bash -c "$COL2; bash" -# 
gnome-terminal -- bash -c "$COL3; bash" -rm agg.log col1.log col2.log col3.log col4.log col5.log -$AGG >>agg.log & -sleep 6 -$COL1 >>col1.log & -sleep 6 -$COL2 >>col2.log & -sleep 6 -$COL3 >> col3.log & -sleep 6 -$COL4 >> col4.log & -sleep 6 -$COL5 >> col5.log & - -wait - -# docker run --env MEDPERF_PARTICIPANT_LABEL=col1@example.com --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/data:/mlcube_io0:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/labels:/mlcube_io1:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/node_cert:/mlcube_io2:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/ca_cert:/mlcube_io3:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace:/mlcube_io4:ro --volume /home/hasan/work/medperf_ws/medperf/examples/fl/fl/mlcube_col1/workspace/logs:/mlcube_io5 -it --entrypoint bash mlcommons/medperf-fl:1.0.0 -# python /mlcube_project/mlcube.py train --data_path=/mlcube_io0 --labels_path=/mlcube_io1 --node_cert_folder=/mlcube_io2 --ca_cert_folder=/mlcube_io3 --plan_path=/mlcube_io4/plan.yaml --output_logs=/mlcube_io5 diff --git a/examples/fl_post/fl/intel_build.sh b/examples/fl_post/fl/intel_build.sh deleted file mode 100755 index 7f8cd19f7..000000000 --- a/examples/fl_post/fl/intel_build.sh +++ /dev/null @@ -1,22 +0,0 @@ -while getopts b flag; do - case "${flag}" in - b) BUILD_BASE="true" ;; - esac -done -BUILD_BASE="${BUILD_BASE:-false}" - -# copy over changes from be_Dockerfile to Dockerfile -cp ./project/be_Dockerfile ./project/Dockerfile - -if ${BUILD_BASE}; then - git clone https://github.com/hasan7n/openfl.git - cd openfl - git checkout be_enable_partial_epochs - docker build \ - --build-arg http_proxy="http://proxy-us.intel.com:912" \ - --build-arg https_proxy="http://proxy-us.intel.com:912" \ - -t local/openfl:local -f openfl-docker/Dockerfile.base . - cd .. 
- rm -rf openfl -fi -mlcube configure --mlcube ./mlcube -Pdocker.build_strategy=always -Pdocker.build_args="--build-arg http_proxy='http://proxy-us.intel.com:912' --build-arg https_proxy='http://proxy-us.intel.com:912'" From 0ed94e95c71c4f1be94db5e25a49dc4454019699 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 10:42:37 -0700 Subject: [PATCH 217/242] removing be_Dockerfile --- examples/fl_post/fl/project/be_Dockerfile | 29 ----------------------- 1 file changed, 29 deletions(-) delete mode 100644 examples/fl_post/fl/project/be_Dockerfile diff --git a/examples/fl_post/fl/project/be_Dockerfile b/examples/fl_post/fl/project/be_Dockerfile deleted file mode 100644 index 874266e3f..000000000 --- a/examples/fl_post/fl/project/be_Dockerfile +++ /dev/null @@ -1,29 +0,0 @@ -FROM local/openfl:local - -ENV LANG C.UTF-8 -ENV CUDA_VISIBLE_DEVICES="0" -# ENV http_proxy="http://proxy-us.intel.com:912" -# ENV https_proxy="http://proxy-us.intel.com:912" -ENV no_proxy=localhost,spr-gpu01.jf.intel.com - -ENV no_proxy_________________________________________________________________________="http://proxy-us.intel.com:912" - -# install project dependencies -RUN apt-get update && apt-get install --no-install-recommends -y git zlib1g-dev libffi-dev libgl1 libgtk2.0-dev gcc g++ - -RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121 - -COPY ./be_requirements.txt /mlcube_project/requirements.txt -RUN pip install --no-cache-dir -r /mlcube_project/requirements.txt - -# Create similar env with cuda118 -RUN apt-get update && apt-get install python3.10-venv -y -RUN python -m venv /cuda118 -RUN /cuda118/bin/pip install --no-cache-dir /openfl -RUN /cuda118/bin/pip install --no-cache-dir torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --extra-index-url https://download.pytorch.org/whl/cu118 -RUN /cuda118/bin/pip install --no-cache-dir -r /mlcube_project/requirements.txt - -# Copy mlcube project folder -COPY 
. /mlcube_project - -ENTRYPOINT ["sh", "/mlcube_project/entrypoint.sh"] From 317cf8733d1b30b899c89e2c4d46c2a63e128bbe Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 10:49:24 -0700 Subject: [PATCH 218/242] doc string --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index e539a8913..bdf6a0393 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -84,9 +84,9 @@ def train_nnunet(actual_max_num_epochs, """ actual_max_num_epochs (int): Provides the number of epochs intended to be trained (this needs to be held constant outside of individual calls to this function during with max_num_epochs is set to one more than the current epoch) - fl_round (int): Federated round, equal to the epoch used for the model (lr scheduling) + fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) val_epoch (bool) : Will validation be performed - train_epoch (bool) : Will training run (rather than val only) makes lr step and epoch increment + train_epoch (bool) : Will training run (rather than val only) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. 
train_cutoff_part (float): Portion of train_val_cutoff going to training val_cutoff_part (float): Portion of train_val_cutoff going to val From e27aaf0fd29b4d65cd7539785c9936a868258d00 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 10:54:39 -0700 Subject: [PATCH 219/242] doc changes and removing validation_only param as we control that another way --- examples/fl_post/fl/project/src/nnunet_v1.py | 30 ++++++-------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index bdf6a0393..89165ab43 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -65,7 +65,6 @@ def train_nnunet(actual_max_num_epochs, task='Task543_FakePostOpp_More', fold='0', continue_training=True, - validation_only=False, c=False, p=plans_param, use_compressed_data=False, @@ -88,11 +87,8 @@ def train_nnunet(actual_max_num_epochs, val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. - train_cutoff_part (float): Portion of train_val_cutoff going to training - val_cutoff_part (float): Portion of train_val_cutoff going to val task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" - validation_only: use this if you want to only run the validation c: use this if you want to continue a training p: plans identifier. Only change this if you created a custom experiment planner use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. 
Reading compressed data " @@ -146,7 +142,6 @@ def __init__(self, **kwargs): fold = args.fold network = args.network network_trainer = args.network_trainer - validation_only = args.validation_only plans_identifier = args.p find_lr = args.find_lr disable_postprocessing_on_folds = args.disable_postprocessing_on_folds @@ -223,7 +218,7 @@ def __init__(self, **kwargs): ) - trainer.initialize(not validation_only) + trainer.initialize(True) if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": trainer.save_checkpoint( @@ -235,22 +230,15 @@ def __init__(self, **kwargs): if find_lr: trainer.find_lr(num_iters=self.actual_max_num_epochs) else: - if not validation_only: - if args.continue_training: - # -c was set, continue a previous training and ignore pretrained weights - trainer.load_latest_checkpoint() - elif (not args.continue_training) and (args.pretrained_weights is not None): - # we start a new training. If pretrained_weights are set, use them - load_pretrained_weights(trainer.network, args.pretrained_weights) - else: - # new training without pretraine weights, do nothing - pass - else: - # if valbest: - # trainer.load_best_checkpoint(train=False) - # else: - # trainer.load_final_checkpoint(train=False) + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. 
If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretraine weights, do nothing + pass # we want latest checkoint only (not best or any intermediate) trainer.save_final_checkpoint = ( From f16f011525b9386260b0ae51986c8cba4c276ff5 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 11:00:13 -0700 Subject: [PATCH 220/242] remiving epochs variable since nnunet trainer run_training is hard coded to train one epoch --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 89165ab43..d3a2719ca 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -125,8 +125,6 @@ def train_nnunet(actual_max_num_epochs, "Optional. Beta. Use with caution." disable_next_stage_pred: If True, do not predict next stage """ - # hard coded, as internal trainer.run_training is currently written to run only one epoch - epochs = 1 class Arguments(): def __init__(self, **kwargs): @@ -256,7 +254,7 @@ def __init__(self, **kwargs): True # if false it will not store/overwrite _latest but separate files each ) - trainer.max_num_epochs = fl_round + epochs + trainer.max_num_epochs = fl_round + 1 trainer.epoch = fl_round # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data From f5ac88f396d3243432b5819c0a7173111781140f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 11:37:25 -0700 Subject: [PATCH 221/242] docstring --- examples/fl_post/fl/project/src/nnunet_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index d3a2719ca..299dfb1ce 100644 --- 
a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -81,8 +81,8 @@ def train_nnunet(actual_max_num_epochs, pretrained_weights=None): """ - actual_max_num_epochs (int): Provides the number of epochs intended to be trained - (this needs to be held constant outside of individual calls to this function during with max_num_epochs is set to one more than the current epoch) + actual_max_num_epochs (int): Provides the number of epochs intended to be trained over the course of the whole federation (for lr scheduling) + (this needs to be held constant outside of individual calls to this function so that the lr is consistetly scheduled) fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) From da0b84bf135a2a5275df4aa225e8672ed65c11d9 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 11:45:33 -0700 Subject: [PATCH 222/242] dockstring and fl plan fixing rounds to train for both agg and task runner blocks --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 4 ++-- examples/fl_post/fl/project/src/nnunet_v1.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 6456c3a02..02eff0bc6 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,7 +5,7 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : 300 + rounds_to_train : &rounds_to_train 2 admins_endpoints_mapping: testfladmin@example.com: - GetExperimentStatus @@ -63,7 +63,7 @@ task_runner : device : cuda gpu_num_string : '0' nnunet_task : Task537_FLPost - 
actual_max_num_epochs : 300 + actual_max_num_epochs : *rounds_to_train network : defaults : plan/defaults/network.yaml diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 299dfb1ce..2e5df028b 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -86,7 +86,6 @@ def train_nnunet(actual_max_num_epochs, fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) val_epoch (bool) : Will validation be performed train_epoch (bool) : Will training run (rather than val only) - train_val_cutoff (int): Total time (in seconds) limit to use in approximating a restriction to training and validation activities. task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" c: use this if you want to continue a training From 2167e7e7f66b4b260651a3e099715ce82a26e308 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Mon, 28 Oct 2024 11:52:55 -0700 Subject: [PATCH 223/242] dockstring --- examples/fl_post/fl/project/src/runner_nnunetv1.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 96cb257cc..cd2c4e61a 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -147,8 +147,6 @@ def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" - # epochs is not used, inside function is hard coded for epochs=1 - self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. 
Insert tensor_dict info into checkpoint self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) From c5161c6bcd00ddae9d77e8c09721e62b0fe79b1c Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 29 Oct 2024 06:40:06 -0700 Subject: [PATCH 224/242] removing comment and test piece in config --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 2 -- examples/fl_post/fl/project/nnunet_model_setup.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 02eff0bc6..d12643dd6 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -12,8 +12,6 @@ aggregator : - SetStragglerCuttoffTime - SetDynamicTaskArg - GetDynamicTaskArg - col1@example.com: - - SetDynamicTaskArg dynamictaskargs: &dynamictaskargs train: diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 2712ccc08..e444b727d 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -68,13 +68,6 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one """ - # Removing 2D data is not longer needed since we set "-pl2d None during plan and preprocessing call" - # TODO: remove this comment once tested - """ - if network != '2d': - delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) - """ - # get or create architecture info model_folder = get_model_folder(network=network, From e05c01dfd5f09848d57385941bcabcd2121a5b7e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 29 Oct 2024 08:41:05 -0700 Subject: [PATCH 225/242] removing 
unused function and providing default args --- .../fl_post/fl/project/src/runner_nnunetv1.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index cd2c4e61a..3191d9a57 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -143,7 +143,7 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time, train_cutoff_time, train_completion_dampener, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time=np.inf, train_cutoff_time=np.inf, train_completion_dampener=0.0, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" @@ -193,16 +193,6 @@ def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.in # TODO: Figure out the right name to use for this method and the default assigner """Perform validation.""" - def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): - hash_1 = np.sum([np.mean(np.array(_value)) for _value in td_1.values()]) - hash_2 = np.sum([np.mean(np.array(_value)) for _value in td_2.values()]) - delta = np.abs(hash_1 - hash_2) - if verbose: - print(f"The tensor dict comparison {tag} resulted in delta: {delta} (accepted error: {epsilon}).") - if delta > epsilon: - raise ValueError(f"The tensor dict comparison {tag} failed with delta: {delta} against an accepted error of: {epsilon}.") - - if not from_checkpoint: self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. 
Insert tensor_dict info into checkpoint @@ -243,9 +233,7 @@ def compare_tensor_dicts(td_1, td_2, tag="", epsilon=0.1, verbose=True): 'val_eval_C4': this_val_eval_metrics_C4} else: checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) - # double check: uncomment below for testing - # compare_tensor_dicts(td_1=input_tensor_dict,td_2=checkpoint_dict['state_dict'], tag="checkpoint VS fromOpenFL") - + all_tr_losses, \ all_val_losses, \ all_val_losses_tr_mode, \ From c396483a7230757a6733a2b93a2ed9d2dd501063 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Tue, 29 Oct 2024 13:31:47 -0700 Subject: [PATCH 226/242] removing a line of print output --- examples/fl_post/fl/project/nnunet_model_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index e444b727d..a647a2f44 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -94,7 +94,7 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) else: - print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n") + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n") shutil.copy(src=plans_path,dst=col_paths['plans_path']) shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) From eb2813107b879a11bcdc4cf010c6cb283fc6971a Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 30 Oct 2024 19:10:34 +0000 Subject: [PATCH 227/242] update testing scripts --- examples/fl_post/fl/README.md | 32 +++++++++++++++++-- examples/fl_post/fl/build.sh | 2 +- 
examples/fl_post/fl/mlcube/mlcube.yaml | 2 +- .../fl/mlcube/workspace/training_config.yaml | 10 +++--- examples/fl_post/fl/setup_test_no_docker.sh | 11 ++++--- examples/fl_post/fl/test.sh | 32 +++++++++++++++---- 6 files changed, 69 insertions(+), 20 deletions(-) diff --git a/examples/fl_post/fl/README.md b/examples/fl_post/fl/README.md index 918f483e3..602746047 100644 --- a/examples/fl_post/fl/README.md +++ b/examples/fl_post/fl/README.md @@ -1,6 +1,34 @@ -# How to run tests +# How to run tests (see next section for a detailed guide) -- Run `setup_test.sh` just once to create certs and download required data. +- Run `setup_test_no_docker.sh` just once to create certs and download required data. - Run `test.sh` to start the aggregator and three collaborators. - Run `clean.sh` to be able to rerun `test.sh` freshly. - Run `setup_clean.sh` to clear what has been generated in step 1. + +## Detailed Guide + +- Go to your medperf repo and checkout the required branch. +- Have medperf virtual environment activated (and medperf installed) +- run: `setup_test_no_docker.sh` to setup the test (you should `setup_clean.sh` if you already ran this before you run it again). +- run: `test.sh --d1 absolute_path --l2 absolute_path ...` to run the test + - data paths can be specified in the command. --dn is for data path of collaborator n, --ln is for labels_path of collaborator n. + - make sure gpu IDs are set as expected in `test.sh` script. +- to stop: `CTRL+C` in the terminal where you ran `test.sh`, then, `docker container ls`, then take the container IDs, then `docker container stop `, to stop relevant running containers (to identify containers to stop, they should have an IMAGE field same name as the one configured in docker image field in `mlcube.yaml`). You can at the end use `docker container prune` to delete all stopped containers if you want (not necessary). +- To rerun: you should first run `sh clean.sh`, then `sh test.sh` again. 
+ +## What to do when you want to + +- change port: either change `setup_test_no_docker.sh` then clean setup and run setup again, or, go to `mlcube_agg/workspace/aggregator_config.yaml` and modify the file directly. +- Change address: change `setup_test_no_docker.sh` then clean setup and run setup again. (since the cert needs to be generated) +- change training_config: modify `mlcube/workspace/training_config.yaml` then run `sync.sh`. +- use custom data paths: pass data paths when running `test.sh` (`--d1, --d2, --l1, ...`) +- change weights: modify `mlcube/workspace/additional_files` then run `sync.sh`. +- fl_admin? connect to container and run fx commands. make sure a colab is an admin (to be detailed later) + +- to use three collaborators instead of two: + - go to `mlcube_agg/workspace/cols.yaml` and modify the list by adding col3. + - in `test.sh`, uncomment col3's run command. + +## to rebuild + +sh build.sh (or with -b if you want to rebuild the openfl base as well. Configure `build.sh` to change how openfl base is built) diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 08cdbb20c..65b2c633f 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout 7c9d4e7039f51014a4f7b3bedf5e2c7f1d353e68 + git checkout d0f6df8ea91e0eaaeabf0691caf0286162df5bd7 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl diff --git a/examples/fl_post/fl/mlcube/mlcube.yaml b/examples/fl_post/fl/mlcube/mlcube.yaml index 7fbe19c0a..835e39ea3 100644 --- a/examples/fl_post/fl/mlcube/mlcube.yaml +++ b/examples/fl_post/fl/mlcube/mlcube.yaml @@ -9,7 +9,7 @@ platform: docker: gpu_args: "--shm-size 12g" # Image name - image: local/tmp:0.0.0 + image: mlcommons/rano-fl:30-oct-2024 # Docker build context relative to $MLCUBE_ROOT. Default is `build`. 
build_context: "../project" # Docker file name within docker build context, default is `Dockerfile`. diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index d12643dd6..0173431b6 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,9 +5,9 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : &rounds_to_train 2 + rounds_to_train : &rounds_to_train 20 admins_endpoints_mapping: - testfladmin@example.com: + col1@example.com: - GetExperimentStatus - SetStragglerCuttoffTime - SetDynamicTaskArg @@ -19,7 +19,7 @@ aggregator : admin_settable: True min: 10 # 10 seconds max: 86400 # one day - value: 86400 # one day + value: 300 # one day val_cutoff_time: admin_settable: True min: 10 # 10 seconds @@ -115,5 +115,5 @@ compression_pipeline : straggler_handling_policy : template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling settings : - straggler_cutoff_time : 600 - minimum_reporting : 5 \ No newline at end of file + straggler_cutoff_time : 1200 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl_post/fl/setup_test_no_docker.sh b/examples/fl_post/fl/setup_test_no_docker.sh index 91ea21ec7..0e354cbb5 100644 --- a/examples/fl_post/fl/setup_test_no_docker.sh +++ b/examples/fl_post/fl/setup_test_no_docker.sh @@ -94,8 +94,8 @@ echo "address: $HOSTNAME_" >>mlcube_agg/workspace/aggregator_config.yaml echo "port: 50273" >>mlcube_agg/workspace/aggregator_config.yaml # cols file +echo "$COL1_LABEL: $COL1_CN" >>mlcube_agg/workspace/cols.yaml echo "$COL2_LABEL: $COL2_CN" >>mlcube_agg/workspace/cols.yaml -echo "$COL3_LABEL: $COL3_CN" >>mlcube_agg/workspace/cols.yaml # for admin ADMIN_CN="testfladmin@example.com" @@ -141,13 +141,14 @@ rm -rf 
small_test_data3 cd ../../ # weights setup -cd mlcube_agg/workspace +cd mlcube/workspace mkdir additional_files cd additional_files wget https://storage.googleapis.com/medperf-storage/fltest29July/flpost_add29july.tar.gz tar -xf flpost_add29july.tar.gz rm flpost_add29july.tar.gz cd ../../../ -cp -r mlcube_agg/workspace/additional_files mlcube_col1/workspace/additional_files -cp -r mlcube_agg/workspace/additional_files mlcube_col2/workspace/additional_files -cp -r mlcube_agg/workspace/additional_files mlcube_col3/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_agg/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_col1/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_col2/workspace/additional_files +cp -r mlcube/workspace/additional_files mlcube_col3/workspace/additional_files diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 9463c56a1..78f7ad7c6 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,7 +1,25 @@ +while getopts d1:l1:d2:l2:d3:l3 flag; do + case "${flag}" in + d1) COL1_DATA=${OPTARG} ;; + l1) COL1_LABELS=${OPTARG} ;; + d2) COL2_DATA=${OPTARG} ;; + l2) COL2_LABELS=${OPTARG} ;; + d3) COL3_DATA=${OPTARG} ;; + l3) COL3_LABELS=${OPTARG} ;; + esac +done + +COL1_DATA="${COL1_DATA:-$PWD/mlcube_col1/workspace/data}" +COL1_LABELS="${COL1_LABELS:-$PWD/mlcube_col1/workspace/labels}" +COL2_DATA="${COL2_DATA:-$PWD/mlcube_col2/workspace/data}" +COL2_LABELS="${COL2_LABELS:-$PWD/mlcube_col2/workspace/labels}" +COL3_DATA="${COL3_DATA:-$PWD/mlcube_col3/workspace/data}" +COL3_LABELS="${COL3_LABELS:-$PWD/mlcube_col3/workspace/labels}" + # generate plan and copy it to each node GENERATE_PLAN_PLATFORM="docker" AGG_PLATFORM="docker" -COL1_PLATFORM="singularity" +COL1_PLATFORM="docker" COL2_PLATFORM="docker" COL3_PLATFORM="docker" @@ -15,9 +33,9 @@ cp ./mlcube_agg/workspace/plan.yaml ./for_admin # Run nodes AGG="medperf --platform $AGG_PLATFORM 
mlcube run --mlcube ./mlcube_agg --task start_aggregator -P 50273" -COL1="medperf --platform $COL1_PLATFORM --gpus=1 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com" -COL2="medperf --platform $COL2_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com" -COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com" +COL1="medperf --platform $COL1_PLATFORM --gpus=device=0 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com --params data_path=$COL1_DATA,labels_path=$COL1_LABELS" +COL2="medperf --platform $COL2_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col2 --task train -e MEDPERF_PARTICIPANT_LABEL=col2@example.com --params data_path=$COL2_DATA,labels_path=$COL2_LABELS" +COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./mlcube_col3 --task train -e MEDPERF_PARTICIPANT_LABEL=col3@example.com --params data_path=$COL3_DATA,labels_path=$COL3_LABELS" # medperf --gpus=device=2 mlcube run --mlcube ./mlcube_col1 --task train -e MEDPERF_PARTICIPANT_LABEL=col1@example.com >>col1.log & @@ -28,9 +46,11 @@ COL3="medperf --platform $COL3_PLATFORM --gpus=device=1 mlcube run --mlcube ./ml rm agg.log col1.log col2.log col3.log $AGG >>agg.log & sleep 6 -$COL2 >>col2.log & +$COL1 >>col1.log & sleep 6 -$COL3 >>col3.log & +$COL2 >>col2.log & +# sleep 6 +# $COL3 >>col3.log & # sleep 6 # $COL2 >> col2.log & # sleep 6 From 6923731001a7c6f5a45a053ebe590404681599a2 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Wed, 30 Oct 2024 19:41:56 +0000 Subject: [PATCH 228/242] update test scripts again --- examples/fl_post/fl/README.md | 8 +++---- examples/fl_post/fl/test.sh | 45 ++++++++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/examples/fl_post/fl/README.md 
b/examples/fl_post/fl/README.md index 602746047..a890dff42 100644 --- a/examples/fl_post/fl/README.md +++ b/examples/fl_post/fl/README.md @@ -1,7 +1,7 @@ # How to run tests (see next section for a detailed guide) - Run `setup_test_no_docker.sh` just once to create certs and download required data. -- Run `test.sh` to start the aggregator and three collaborators. +- Run `test.sh ` to start the aggregator and three collaborators. (requires BASH, see next section) - Run `clean.sh` to be able to rerun `test.sh` freshly. - Run `setup_clean.sh` to clear what has been generated in step 1. @@ -9,12 +9,12 @@ - Go to your medperf repo and checkout the required branch. - Have medperf virtual environment activated (and medperf installed) -- run: `setup_test_no_docker.sh` to setup the test (you should `setup_clean.sh` if you already ran this before you run it again). -- run: `test.sh --d1 absolute_path --l2 absolute_path ...` to run the test +- run: `sh setup_test_no_docker.sh` to setup the test (you should `sh setup_clean.sh` if you already ran this before you run it again). +- run: `bash test.sh --d1 absolute_path --l2 absolute_path ...` to run the test - data paths can be specified in the command. --dn is for data path of collaborator n, --ln is for labels_path of collaborator n. - make sure gpu IDs are set as expected in `test.sh` script. - to stop: `CTRL+C` in the terminal where you ran `test.sh`, then, `docker container ls`, then take the container IDs, then `docker container stop `, to stop relevant running containers (to identify containers to stop, they should have an IMAGE field same name as the one configured in docker image field in `mlcube.yaml`). You can at the end use `docker container prune` to delete all stopped containers if you want (not necessary). -- To rerun: you should first run `sh clean.sh`, then `sh test.sh` again. +- To rerun: you should first run `sh clean.sh`, then `bash test.sh ...` again. 
## What to do when you want to diff --git a/examples/fl_post/fl/test.sh b/examples/fl_post/fl/test.sh index 78f7ad7c6..f37658ed1 100755 --- a/examples/fl_post/fl/test.sh +++ b/examples/fl_post/fl/test.sh @@ -1,12 +1,41 @@ -while getopts d1:l1:d2:l2:d3:l3 flag; do - case "${flag}" in - d1) COL1_DATA=${OPTARG} ;; - l1) COL1_LABELS=${OPTARG} ;; - d2) COL2_DATA=${OPTARG} ;; - l2) COL2_LABELS=${OPTARG} ;; - d3) COL3_DATA=${OPTARG} ;; - l3) COL3_LABELS=${OPTARG} ;; +COL1_DATA="" +COL1_LABELS="" +COL2_DATA="" +COL2_LABELS="" +COL3_DATA="" +COL3_LABELS="" +while [[ "$#" -gt 0 ]]; do + case $1 in + --d1) + COL1_DATA="$2" + shift + ;; + --d2) + COL2_DATA="$2" + shift + ;; + --d3) + COL3_DATA="$2" + shift + ;; + --l1) + COL1_LABELS="$2" + shift + ;; + --l2) + COL2_LABELS="$2" + shift + ;; + --l3) + COL3_LABELS="$2" + shift + ;; + *) + echo "Unknown parameter: $1" + exit 1 + ;; esac + shift done COL1_DATA="${COL1_DATA:-$PWD/mlcube_col1/workspace/data}" From 9e18f82f0a982c002f829082456dd5d0878a4239 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 31 Oct 2024 20:41:23 +0000 Subject: [PATCH 229/242] restart training on failure --- cli/medperf/commands/dataset/dataset.py | 9 ++++- cli/medperf/commands/dataset/train.py | 50 ++++++++++++++++++++----- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 375f141ec..9a3ab2f3a 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -157,10 +157,17 @@ def train( overwrite: bool = typer.Option( False, "--overwrite", help="Overwrite outputs if present" ), + restart_on_failure: bool = typer.Option( + False, + "--restart_on_failure", + help="Keep restarting failing training processes until Keyboard interrupt", + ), approval: bool = typer.Option(False, "-y", help="Skip approval step"), ): """Runs training""" - TrainingExecution.run(training_exp_id, data_uid, overwrite, approval) + 
TrainingExecution.run( + training_exp_id, data_uid, overwrite, approval, restart_on_failure + ) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/dataset/train.py b/cli/medperf/commands/dataset/train.py index 867b9cc27..08105f26c 100644 --- a/cli/medperf/commands/dataset/train.py +++ b/cli/medperf/commands/dataset/train.py @@ -3,7 +3,12 @@ from medperf.account_management.account_management import get_medperf_user_data from medperf.entities.ca import CA from medperf.entities.event import TrainingEvent -from medperf.exceptions import CleanExit, InvalidArgumentError, MedperfException +from medperf.exceptions import ( + CleanExit, + ExecutionError, + InvalidArgumentError, + MedperfException, +) from medperf.entities.training_exp import TrainingExp from medperf.entities.dataset import Dataset from medperf.entities.cube import Cube @@ -25,22 +30,36 @@ def run( data_uid: int, overwrite: bool = False, approved: bool = False, + restart_on_failure: bool = False, ): """Starts the aggregation server of a training experiment Args: training_exp_id (int): Training experiment UID. 
""" + if restart_on_failure: + approved = True + overwrite = True execution = cls(training_exp_id, data_uid, overwrite, approved) - execution.prepare() - execution.validate() - execution.check_existing_outputs() - execution.prepare_plan() - execution.prepare_pki_assets() - execution.confirm_run() - with config.ui.interactive(): - execution.prepare_training_cube() - execution.run_experiment() + if restart_on_failure: + execution.confirm_restart_on_failure() + + while True: + execution.prepare() + execution.validate() + execution.check_existing_outputs() + execution.prepare_plan() + execution.prepare_pki_assets() + execution.confirm_run() + with config.ui.interactive(): + execution.prepare_training_cube() + try: + execution.run_experiment() + break + except ExecutionError as e: + print(str(e)) + if not restart_on_failure: + break def __init__( self, training_exp_id: int, data_uid: int, overwrite: bool, approved: bool @@ -51,6 +70,17 @@ def __init__( self.ui = config.ui self.approved = approved + def confirm_restart_on_failure(self): + msg = ( + "You chose to restart on failure. This means that the training process" + " will automatically restart, without your approval, even if training configuration" + " changes from the server side. Do you confirm? [Y/n] " + ) + if not approval_prompt(msg): + raise CleanExit( + "Training cancelled. Rerun without the --restart_on_failure flag." + ) + def prepare(self): self.training_exp = TrainingExp.get(self.training_exp_id) self.ui.print(f"Training Execution: {self.training_exp.name}") From 11fd6e766d06494276194141803e78f3d1769981 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 31 Oct 2024 22:58:06 +0000 Subject: [PATCH 230/242] Revert "Merge pull request #7 from hasan7n/be_enable_partial_epochs" This reverts commit 128e28b8d4f7fafc4dab8a676f4c11982893921c, reversing changes made to 26b4337f431921cfe559c6ca10ca17998bbb3850. 
--- .../fl/mlcube/workspace/training_config.yaml | 154 ++++-------- .../fl_post/fl/project/nnunet_data_setup.py | 16 +- .../fl_post/fl/project/nnunet_model_setup.py | 61 ++++- examples/fl_post/fl/project/nnunet_setup.py | 5 +- .../fl/project/src/nnunet_dummy_dataloader.py | 2 +- examples/fl_post/fl/project/src/nnunet_v1.py | 138 +++++------ .../fl_post/fl/project/src/runner_nnunetv1.py | 223 +++++++----------- .../fl_post/fl/project/src/runner_pt_chkpt.py | 27 +-- 8 files changed, 272 insertions(+), 354 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 0173431b6..8a73fe439 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -1,119 +1,69 @@ -aggregator : - defaults : plan/defaults/aggregator.yaml - template : openfl.component.Aggregator - settings : - init_state_path : save/fl_post_two_init.pbuf - best_state_path : save/fl_post_two_best.pbuf - last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : &rounds_to_train 20 +aggregator: + defaults: plan/defaults/aggregator.yaml + template: openfl.component.Aggregator + settings: + init_state_path: save/fl_post_two_init.pbuf + best_state_path: save/fl_post_two_best.pbuf + last_state_path: save/fl_post_two_last.pbuf + rounds_to_train: 2 admins_endpoints_mapping: col1@example.com: - GetExperimentStatus - SetStragglerCuttoffTime - - SetDynamicTaskArg - - GetDynamicTaskArg - dynamictaskargs: &dynamictaskargs - train: - train_cutoff_time: - admin_settable: True - min: 10 # 10 seconds - max: 86400 # one day - value: 300 # one day - val_cutoff_time: - admin_settable: True - min: 10 # 10 seconds - max: 86400 # one day - value: 86400 # one day - train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: Value close to zero zero shifts non 0.0 completion rates much closer to 1.0 - admin_settable: 
True - min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) - max: 1.0 # leaves completion rates as is - value: 1.0 +collaborator: + defaults: plan/defaults/collaborator.yaml + template: openfl.component.Collaborator + settings: + delta_updates: false + opt_treatment: CONTINUE_LOCAL - aggregated_model_validation: - val_cutoff_time: - admin_settable: True - min: 10 # 10 seconds - max: 86400 # one day - value: 86400 # one day - - -collaborator : - defaults : plan/defaults/collaborator.yaml - template : openfl.component.Collaborator - settings : - delta_updates : false - opt_treatment : CONTINUE_LOCAL - dynamictaskargs: *dynamictaskargs - -data_loader : - defaults : plan/defaults/data_loader.yaml - template : src.nnunet_dummy_dataloader.NNUNetDummyDataLoader - settings : - p_train : 0.8 +data_loader: + defaults: plan/defaults/data_loader.yaml + template: src.nnunet_dummy_dataloader.NNUNetDummyDataLoader + settings: + p_train: 0.8 # TODO: make checkpoint-only truly generic and create the task runner within src -task_runner : - defaults : plan/defaults/task_runner.yaml - template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner - settings : - device : cuda - gpu_num_string : '0' - nnunet_task : Task537_FLPost - actual_max_num_epochs : *rounds_to_train - -network : - defaults : plan/defaults/network.yaml +task_runner: + defaults: plan/defaults/task_runner.yaml + template: src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner + settings: + device: cuda + gpu_num_string: "0" + nnunet_task: Task537_FLPost + +network: + defaults: plan/defaults/network.yaml settings: {} -assigner : - defaults : plan/defaults/assigner.yaml - template : openfl.component.assigner.DynamicRandomGroupedAssigner - settings : - task_groups : - - name : train_and_validate - percentage : 1.0 - tasks : - - aggregated_model_validation +assigner: + defaults: plan/defaults/assigner.yaml + 
template: openfl.component.assigner.DynamicRandomGroupedAssigner + settings: + task_groups: + - name: train_and_validate + percentage: 1.0 + tasks: + # - aggregated_model_validation - train - - locally_tuned_model_validation + # - locally_tuned_model_validation -tasks : - defaults : plan/defaults/tasks_torch.yaml - aggregated_model_validation: - function : validate - kwargs : - metrics : - - val_eval - - val_eval_C1 - - val_eval_C2 - - val_eval_C3 - - val_eval_C4 - apply : global +tasks: + defaults: plan/defaults/tasks_torch.yaml train: - function : train - kwargs : - metrics : + function: train + kwargs: + metrics: - train_loss - epochs : 1 - locally_tuned_model_validation: - function : validate - kwargs : - metrics : - val_eval - - val_eval_C1 - - val_eval_C2 - - val_eval_C3 - - val_eval_C4 - apply : local - from_checkpoint: true + epochs: 1 -compression_pipeline : - defaults : plan/defaults/compression_pipeline.yaml +compression_pipeline: + defaults: plan/defaults/compression_pipeline.yaml -straggler_handling_policy : - template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling - settings : - straggler_cutoff_time : 1200 - minimum_reporting : 2 \ No newline at end of file +straggler_handling_policy: + template: openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings: + straggler_cutoff_time: 600 + minimum_reporting: 2 diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index db3894e9d..3f24c9515 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -246,8 +246,11 @@ def setup_fl_data(postopp_pardir, should be run using a virtual environment that has nnunet version 1 installed. args: - postopp_pardir(str) : Parent directory for postopp data. 
- This directory should have 'data' and 'labels' subdirectories, with structure: + postopp_pardirs(list of str) : Parent directories for postopp data. The length of the list should either be + equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, + the samples within that single directory will be used to create num_insitutions shards. + If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. + Either way, all string paths within this list should point to folders that have 'data' and 'labels' subdirectories with structure: ├── data │ ├── AAAC_0 │ │ ├── 2008.03.30 @@ -295,7 +298,7 @@ def setup_fl_data(postopp_pardir, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - three_digit_task_num(str): Should start with '5'. + three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. task_name(str) : Any string task name.
percent_train(float) : what percent of data is put into the training data split (rest to val) split_logic(str) : Determines how train/val split is performed @@ -333,6 +336,7 @@ def setup_fl_data(postopp_pardir, # Track the subjects and timestamps for each shard subject_to_timestamps = {} + print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") for postopp_subject_dir in all_subjects: subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, postopp_data_dirpath=postopp_data_dirpath, @@ -352,12 +356,12 @@ def setup_fl_data(postopp_pardir, # Now call the os process to preprocess the data print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") if plans_path: - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"]) - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) plans_identifier_for_model_writing = shared_plans_identifier else: # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) plans_identifier_for_model_writing = local_plans_identifier # Now compute our own stratified splits file, keeping all 
timestampts for a given subject exclusively in either train or val diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index a647a2f44..4ebd1f9e7 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -7,13 +7,12 @@ def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): os.environ['CUDA_VISIBLE_DEVICES']=cuda_device - print(f"###########\nStarting training a single epoch for task: {task}\n") - # Function below is now hard coded for a single epoch of training. - train_nnunet(actual_max_num_epochs=1000, - fl_round=current_epoch, - network=network, + print(f"###########\nStarting training for task: {task}\n") + train_nnunet(epochs=1, + current_epoch = current_epoch, + network = network, task=task, - network_trainer=network_trainer, + network_trainer = network_trainer, fold=fold, continue_training=continue_training, p=plans_identifier) @@ -61,13 +60,61 @@ def delete_2d_data(network, task, plans_identifier): print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") shutil.rmtree(data_dir_2d) +""" +def normalize_architecture(reference_plan_path, target_plan_path): + + # Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path + # NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more + likely leed to an exception. 
+ + + + assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', + 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', + 'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier'] + copy_over_keys = ['plans_per_stage'] + nullify_keys = ['original_spacings', 'original_sizes'] + leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties'] + + + # check I got all keys here + assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage']) + + def get_pickle_obj(path): + with open(path, 'rb') as _file: + plan= pkl.load(_file) + return plan + + def write_pickled_obj(obj, path): + with open(path, 'wb') as _file: + pkl.dump(obj, _file) + + reference_plan = get_pickle_obj(path=reference_plan_path) + target_plan = get_pickle_obj(path=target_plan_path) + + for key in assert_same_keys: + if reference_plan[key] != target_plan[key]: + raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}") + for key in copy_over_keys: + target_plan[key] = reference_plan[key] + for key in nullify_keys: + target_plan[key] = None + # leave alone keys are left alone :) + + # write back to target plan + write_pickled_obj(obj=target_plan, path=target_plan_path) +""" def trim_data_and_setup_model(task, network, network_trainer, 
plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'): """ Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one """ + # Remove 2D data and 2D data info if appropriate + if network != '2d': + delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) + # get or create architecture info model_folder = get_model_folder(network=network, @@ -94,7 +141,7 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) else: - print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n") + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n") shutil.copy(src=plans_path,dst=col_paths['plans_path']) shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 81f6787d8..0106b8f95 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -28,7 +28,7 @@ def main(postopp_pardir, plans_path=None, local_plans_identifier=local_plans_identifier, shared_plans_identifier=shared_plans_identifier, - overwrite_nnunet_datadirs=True, + overwrite_nnunet_datadirs=False, timestamp_selection='all', cuda_device='0', verbose=False): @@ -105,7 +105,7 @@ def main(postopp_pardir, fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. 
shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. - overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories for given task number and name. + overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one less than number of institutions. task_name(str) : Any string task name. timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far @@ -126,7 +126,6 @@ def main(postopp_pardir, # task_folder_info is a zipped lists indexed over tasks (collaborators) # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) - col_paths = setup_fl_data(postopp_pardir=postopp_pardir, three_digit_task_num=three_digit_task_num, task_name=task_name, diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py index 68cbbbc40..1fe83a4f5 100644 --- a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -33,4 +33,4 @@ def get_valid_data_size(self): return self.valid_data_size def get_task_name(self): - return self.task_name \ No newline at end of file + return self.task_name diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 2e5df028b..78869d4c1 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,17 +54,14 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(actual_max_num_epochs, - fl_round, - val_epoch=True, - train_epoch=True, - train_cutoff=np.inf, - val_cutoff=np.inf, +def train_nnunet(epochs, + current_epoch,
network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', fold='0', continue_training=True, + validation_only=False, c=False, p=plans_param, use_compressed_data=False, @@ -81,13 +78,9 @@ def train_nnunet(actual_max_num_epochs, pretrained_weights=None): """ - actual_max_num_epochs (int): Provides the number of epochs intended to be trained over the course of the whole federation (for lr scheduling) - (this needs to be held constant outside of individual calls to this function so that the lr is consistetly scheduled) - fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) - val_epoch (bool) : Will validation be performed - train_epoch (bool) : Will training run (rather than val only) task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" + validation_only: use this if you want to only run the validation c: use this if you want to continue a training p: plans identifier. Only change this if you created a custom experiment planner use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. 
Reading compressed data " @@ -139,6 +132,7 @@ def __init__(self, **kwargs): fold = args.fold network = args.network network_trainer = args.network_trainer + validation_only = args.validation_only plans_identifier = args.p find_lr = args.find_lr disable_postprocessing_on_folds = args.disable_postprocessing_on_folds @@ -204,7 +198,6 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, - actual_max_num_epochs=actual_max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, @@ -213,30 +206,6 @@ def __init__(self, **kwargs): deterministic=deterministic, fp16=run_mixed_precision, ) - - - trainer.initialize(True) - - if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": - trainer.save_checkpoint( - join(trainer.output_folder, "model_final_checkpoint.model") - ) - print("Preparation round: Model-averaging") - return - - if find_lr: - trainer.find_lr(num_iters=self.actual_max_num_epochs) - else: - if args.continue_training: - # -c was set, continue a previous training and ignore pretrained weights - trainer.load_latest_checkpoint() - elif (not args.continue_training) and (args.pretrained_weights is not None): - # we start a new training. 
If pretrained_weights are set, use them - load_pretrained_weights(trainer.network, args.pretrained_weights) - else: - # new training without pretraine weights, do nothing - pass - # we want latest checkoint only (not best or any intermediate) trainer.save_final_checkpoint = ( True # whether or not to save the final checkpoint @@ -252,44 +221,61 @@ def __init__(self, **kwargs): trainer.save_latest_only = ( True # if false it will not store/overwrite _latest but separate files each ) + trainer.max_num_epochs = current_epoch + epochs + trainer.epoch = current_epoch + + # TODO: call validation separately + trainer.initialize(not validation_only) + + if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": + trainer.save_checkpoint( + join(trainer.output_folder, "model_final_checkpoint.model") + ) + print("Preparation round: Model-averaging") + return + + if find_lr: + trainer.find_lr() + else: + if not validation_only: + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights + trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. 
If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretraine weights, do nothing + pass + + trainer.run_training() + else: + # if valbest: + # trainer.load_best_checkpoint(train=False) + # else: + # trainer.load_final_checkpoint(train=False) + trainer.load_latest_checkpoint() - trainer.max_num_epochs = fl_round + 1 - trainer.epoch = fl_round - - # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data - # point is expected to be seen epochs number of times - - num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) - - # the nnunet trainer attributes have a different naming convention than I am using - trainer.num_batches_per_epoch = num_train_batches_per_epoch - trainer.num_val_batches_per_epoch = num_val_batches_per_epoch - - batches_applied_train, \ - batches_applied_val, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, - val_cutoff=val_cutoff, - val_epoch=val_epoch, - train_epoch=train_epoch) - - train_completed = batches_applied_train / float(num_train_batches_per_epoch) - val_completed = batches_applied_val / float(num_val_batches_per_epoch) - - return train_completed, \ - val_completed, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 - - + trainer.network.eval() + + # if fold == "all": + # print("--> fold == 'all'") + # print("--> DONE") + # else: + # # predict validation + # trainer.validate( + # save_softmax=args.npz, + # validation_folder_name=val_folder, + 
# run_postprocessing_on_folds=not disable_postprocessing_on_folds, + # overwrite=args.val_disable_overwrite, + # ) + + # if network == "3d_lowres" and not args.disable_next_stage_pred: + # print("predicting segmentations for the next stage of the cascade") + # predict_next_stage( + # trainer, + # join( + # dataset_directory, + # trainer.plans["data_identifier"] + "_stage%d" % 1, + # ), + # ) diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 3191d9a57..26f184522 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? """ # TODO: Clean up imports +# TODO: ask Micah if this has to be changed (most probably no) import os import subprocess @@ -35,15 +36,12 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): def __init__(self, nnunet_task=None, config_path=None, - actual_max_num_epochs=1000, **kwargs): """Initialize. Args: - nnunet_task (str) : Task string used to identify the data and model folders - config_path(str) : Path to the configuration file used by the training and validation script. - actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating - kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). + config_path(str) : Path to the configuration file used by the training and validation script. + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
TODO: """ @@ -74,14 +72,6 @@ def __init__(self, ) self.config_path = config_path - self.actual_max_num_epochs=actual_max_num_epochs - - # self.task_completed is a dictionary of task to amount completed as a float in [0,1] - # Values will be dynamically updated - # TODO: Tasks are hard coded for now - self.task_completed = {'aggregated_model_validation': 1.0, - 'train': 1.0, - 'locally_tuned_model_validation': 1.0} def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -118,10 +108,11 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): # get device for correct placement of tensors device = self.device - checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) + checkpoint_dict = self.load_checkpoint(map_location=device) epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict + # Brandon DEBUGGING seen_keys = [] for k in checkpoint_dict['state_dict']: if k not in seen_keys: @@ -143,113 +134,91 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time=np.inf, train_cutoff_time=np.inf, train_completion_dampener=0.0, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. 
Insert tensor_dict info into checkpoint - self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - self.logger.info(f"Training for round:{round_num}") - train_completed, \ - val_completed, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - fl_round=round_num, - train_cutoff=train_cutoff_time, - val_cutoff = val_cutoff_time, - task=self.data_loader.get_task_name(), - val_epoch=True, - train_epoch=True) - - # dampen the train_completion + current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + # 2. Train function existing externally + # Some todo inside function below + # TODO: test for off-by-one error + # TODO: we need to disable validation if possible, and separately call validation + + # FIXME: we need to understand how to use round_num instead of current_epoch + # this will matter in straggler handling cases + # TODO: Should we put this in a separate process? + train_nnunet(epochs=epochs, current_epoch=current_epoch, task=self.data_loader.get_task_name()) + + # 3. 
Load metrics from checkpoint + (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] + # these metrics are appended to the checkopint each epoch, so we select the most recent epoch + metrics = {'train_loss': all_tr_losses[-1], + 'val_eval': all_val_eval_metrics[-1]} + + return self.convert_results_to_tensorkeys(col_name, round_num, metrics) + + + + def validate(self, col_name, round_num, input_tensor_dict, **kwargs): """ - values in range: (0, 1] with values near 0.0 making all train_completion rates shift nearer to 1.0, thus making the - trained model update weighting during aggregation stay closer to the plain data size weighting - specifically, update_weight = train_data_size / train_completed**train_completion_dampener + Run the trained model on validation data; report results. + + Parameters + ---------- + input_tensor_dict : either the last aggregated or locally trained model + + Returns + ------- + output_tensor_dict : {TensorKey: nparray} (these correspond to acc, + precision, f1_score, etc.) """ - train_completed = train_completed**train_completion_dampener - # update amount of task completed - self.task_completed['train'] = train_completed - self.task_completed['locally_tuned_model_validation'] = val_completed + raise NotImplementedError() - # 3. Prepare metrics - metrics = {'train_loss': this_ave_train_loss} + """ - TBD - for now commenting out - global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) + self.rebuild_model(round_num, input_tensor_dict, validation=True) - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - - return global_tensor_dict, local_tensor_dict - + # 1. 
Save model in native format + self.save_native(self.mlcube_model_in_path) - def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.inf, from_checkpoint=False, **kwargs): - # TODO: Figure out the right name to use for this method and the default assigner - """Perform validation.""" - - if not from_checkpoint: - self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) - # 1. Insert tensor_dict info into checkpoint - self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - self.logger.info(f"Validating for round:{round_num}") - # 2. Train/val function existing externally - # Some todo inside function below - train_completed, \ - val_completed, \ - this_ave_train_loss, \ - this_ave_val_loss, \ - this_val_eval_metrics, \ - this_val_eval_metrics_C1, \ - this_val_eval_metrics_C2, \ - this_val_eval_metrics_C3, \ - this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, - fl_round=round_num, - train_cutoff=0, - val_cutoff = val_cutoff_time, - task=self.data_loader.get_task_name(), - val_epoch=True, - train_epoch=False) - # double check - if train_completed != 0.0: - raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") - - # update amount of task completed - self.task_completed['aggregated_model_validation'] = val_completed - - self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") - - - # 3. Prepare metrics - metrics = {'val_eval': this_val_eval_metrics, - 'val_eval_C1': this_val_eval_metrics_C1, - 'val_eval_C2': this_val_eval_metrics_C2, - 'val_eval_C3': this_val_eval_metrics_C3, - 'val_eval_C4': this_val_eval_metrics_C4} + # 2. 
Call MLCube validate task + platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) + task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') + proc = subprocess.run(["mlcube_docker", + "run", + "--mlcube={}".format(self.mlcube_dir), + "--platform={}".format(platform_yaml), + "--task={}".format(task_yaml)]) + + # 3. Load any metrics + metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) + + # set the validation data size + sample_count = int(metrics.pop(self.evaluation_sample_count_key)) + self.data_loader.set_valid_data_size(sample_count) + + # 4. Convert to tensorkeys + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' else: - checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) - - all_tr_losses, \ - all_val_losses, \ - all_val_losses_tr_mode, \ - all_val_eval_metrics, \ - all_val_eval_metrics_C1, \ - all_val_eval_metrics_C2, \ - all_val_eval_metrics_C3, \ - all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] - # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after - metrics = {'val_eval': all_val_eval_metrics[-1], - 'val_eval_C1': all_val_eval_metrics_C1[-1], - 'val_eval_C2': all_val_eval_metrics_C2[-1], - 'val_eval_C3': all_val_eval_metrics_C3[-1], - 'val_eval_C4': all_val_eval_metrics_C4[-1]} - - return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) + suffix += '_agg' + tags = ('metric', suffix) + output_tensor_dict = { + TensorKey( + metric_name, origin, round_num, True, tags + ): np.array(metrics[metric_name]) + for metric_name in metrics + } + + return output_tensor_dict, {} + + """ def load_metrics(self, filepath): @@ -261,38 +230,4 @@ def load_metrics(self, filepath): with open(filepath) as json_file: metrics = json.load(json_file) return metrics - """ - - - def 
get_train_data_size(self, task_name=None): - """Get the number of training examples. - - It will be used for weighted averaging in aggregation. - This overrides the parent class method, - allowing dynamic weighting by storing recent appropriate weights in class attributes. - - Returns: - int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema - """ - if not task_name: - return self.data_loader.get_train_data_size() - else: - # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_train_data_size())) - - - def get_valid_data_size(self, task_name=None): - """Get the number of training examples. - - It will be used for weighted averaging in aggregation. - This overrides the parent class method, - allowing dynamic weighting by storing recent appropriate weights in class attributes. - - Returns: - int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema - """ - if not task_name: - return self.data_loader.get_valid_data_size() - else: - # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] - return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_valid_data_size())) + """ \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index a7fbd2056..6ab7851b9 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? 
""" # TODO: Clean up imports @@ -82,10 +82,12 @@ def __init__(self, self.replace_checkpoint(self.checkpoint_path_initial) - def load_checkpoint(self, checkpoint_path, map_location=None): + def load_checkpoint(self, checkpoint_path=None, map_location=None): """ Function used to load checkpoint from disk. """ + if not checkpoint_path: + checkpoint_path = self.checkpoint_path_load checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) return checkpoint_dict @@ -122,7 +124,7 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): return self.required_tensorkeys_for_function[func_name] def reset_opt_vars(self): - current_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + current_checkpoint_dict = self.load_checkpoint() initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, @@ -170,7 +172,7 @@ def read_tensors_from_checkpoint(self, with_opt_vars): dict: Tensor dictionary {**dict, **optimizer_dict} """ - checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + checkpoint_dict = self.load_checkpoint() state = to_cpu_numpy(checkpoint_dict['state_dict']) if with_opt_vars: opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) @@ -253,9 +255,7 @@ def _read_opt_state_from_checkpoint(self, checkpoint_dict): return derived_opt_state_dict - def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_model): - # insert_model determined whether or not to include the model in the return dictionaries - + def convert_results_to_tensorkeys(self, col_name, round_num, metrics): # 5. 
Convert to tensorkeys # output metric tensors (scalar) @@ -268,14 +268,11 @@ def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_mod metrics[metric_name] ) for metric_name in metrics} - if insert_model: - # output model tensors (Doesn't include TensorKey) - output_model_dict = self.get_tensor_dict(with_opt_vars=True) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, - tensor_dict=output_model_dict, - **self.tensor_dict_split_fn_kwargs) - else: - global_model_dict, local_model_dict = {}, {} + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) # create global tensorkeys global_tensorkey_model_dict = { From 68577a39370a0316804dd2c76acc28511a06ccb4 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 31 Oct 2024 23:08:33 +0000 Subject: [PATCH 231/242] use fixed hashes for openfl and nnunet installations --- examples/fl_post/fl/build.sh | 2 +- examples/fl_post/fl/project/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index 65b2c633f..b4521e9a1 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout d0f6df8ea91e0eaaeabf0691caf0286162df5bd7 + git checkout b5e26ac33935966800b6a5b61e85b823cc68c4da docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. 
rm -rf openfl diff --git a/examples/fl_post/fl/project/requirements.txt b/examples/fl_post/fl/project/requirements.txt index 8f03308f9..569eaf802 100644 --- a/examples/fl_post/fl/project/requirements.txt +++ b/examples/fl_post/fl/project/requirements.txt @@ -1,4 +1,4 @@ onnx==1.13.0 typer==0.9.0 -git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@main#egg=nnunet +git+https://github.com/brandon-edwards/nnUNet_v1.7.1_local.git@077f4852c81da1d1e1141547cbd09ac68ade4b5b#egg=nnunet numpy==1.26.4 From 4b0d1b041c5af86eb5a17ff581fd3a655ee1b152 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 1 Nov 2024 01:56:27 +0000 Subject: [PATCH 232/242] update openfl commit --- examples/fl_post/fl/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/build.sh b/examples/fl_post/fl/build.sh index b4521e9a1..3b9135c3c 100755 --- a/examples/fl_post/fl/build.sh +++ b/examples/fl_post/fl/build.sh @@ -8,7 +8,7 @@ BUILD_BASE="${BUILD_BASE:-false}" if ${BUILD_BASE}; then git clone https://github.com/hasan7n/openfl.git cd openfl - git checkout b5e26ac33935966800b6a5b61e85b823cc68c4da + git checkout 3ed63c9ed24311b4ad581da5b859b04853c08375 docker build -t local/openfl:local -f openfl-docker/Dockerfile.base . cd .. rm -rf openfl From 1330f926e2e3d27fc9f9d5c582995e058bc65d7e Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 17:10:03 -0800 Subject: [PATCH 233/242] changing the number of val batches and training batches to apply locally to 50 and 250 respectively. Timeouts will still apply to stop early if it is taking too long. Amount of training completed will be computed off of these new maxes, i.e. training_completed of 1.0 will mean all of the 250 batches were trained. With this change we will want model updates to be counted according to local data size, therefore it is important that the training config set the train_completion_dampener at the value: 0.0. 
--- examples/fl_post/fl/project/src/nnunet_v1.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 2e5df028b..b162eb1e0 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -256,11 +256,10 @@ def __init__(self, **kwargs): trainer.max_num_epochs = fl_round + 1 trainer.epoch = fl_round - # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data - # point is expected to be seen epochs number of times - - num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) - num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + # STAYing WITH NNUNET CONVENTION OF 50 AND 250 VAL AND TRAIN BATCHES RESPECTIVELY + # Note: This convention makes sense in combination with a train_completion_dampener of 0.0 + num_val_batches_per_epoch = 50 + num_train_batches_per_epoch = 250 # the nnunet trainer attributes have a different naming convention than I am using trainer.num_batches_per_epoch = num_train_batches_per_epoch From e706a71d97861c4ab153ca695d4e287457b6686a Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 17:15:57 -0800 Subject: [PATCH 234/242] setting the train_completion_dampener to 0.0 --- examples/fl_post/fl/mlcube/workspace/training_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index d12643dd6..7c118424c 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -29,7 +29,7 @@ aggregator : admin_settable: True min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as 
opposed to data_size / train_completed) max: 1.0 # leaves completion rates as is - value: 1.0 + value: 0.0 aggregated_model_validation: val_cutoff_time: From c0c68034e0e434ac84723b34006757c0e25ceec7 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 17:26:43 -0800 Subject: [PATCH 235/242] Evan had the following order specified in his 'order' file given to us related to his new initial model: order = t1 t2 flair t1c --- examples/fl_post/fl/project/nnunet_data_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index db3894e9d..df176ac67 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -14,8 +14,8 @@ num_to_modality = {'_0000': '_brain_t1n.nii.gz', '_0001': '_brain_t2w.nii.gz', - '_0002': '_brain_t1c.nii.gz', - '_0003': '_brain_t2f.nii.gz'} + '_0002': '_brain_t2f.nii.gz', + '_0003': '_brain_t1c.nii.gz'} def get_subdirs(parent_directory): subjects = os.listdir(parent_directory) From 915bb61418d3f226728e409aa5870f2d4a921e4b Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 17:51:41 -0800 Subject: [PATCH 236/242] making train/val split random operations depend on a seed, default seed provided --- .../fl_post/fl/project/nnunet_data_setup.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index df176ac67..b8e5b18ba 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -115,14 +115,16 @@ def doublecheck_postopp_pardir(postopp_pardir, verbose=False): raise ValueError(f"'labels' must be a subdirectory of postopp_src_pardir:{postopp_pardir}, but it is not.") -def split_by_subject(subject_to_timestamps, percent_train, verbose=False): +def 
split_by_subject(subject_to_timestamps, percent_train, split_seed, verbose=False): """ NOTE: An attempt is made to put percent_train of the total subjects into train (as opposed to val) regardless of how many timestamps there are for each subject. No subject is allowed to have samples in both train and val. """ subjects = list(subject_to_timestamps.keys()) - np.random.shuffle(subjects) + # create a random number generator with our seed + rng = np.random.default_rng(split_seed) + rng.shuffle(subjects) train_cutoff = int(len(subjects) * percent_train) @@ -132,7 +134,7 @@ def split_by_subject(subject_to_timestamps, percent_train, verbose=False): return train_subject_to_timestamps, val_subject_to_timestamps -def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, verbose=False): +def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, split_seed, verbose=False): """ NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios. No subject is allowed to have samples in both train and val. 
@@ -143,9 +145,11 @@ def percent_train_for_split(train_subjects, grand_total): sub_total += subject_counts[subject] return sub_total/grand_total - def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): + def shuffle_and_cut(subject_counts, grand_total, percent_train, seed, verbose=False): subjects = list(subject_counts.keys()) - np.random.shuffle(subjects) + # create a random number generator with our seed + rng = np.random.default_rng(seed) + rng.shuffle(subjects) for idx in range(2,len(subjects)+1): train_subjects = subjects[:idx-1] val_subjects = subjects[idx-1:] @@ -172,8 +176,9 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): best_percent_train = percent_train_for_split(train_subjects=best_train_subjects, grand_total=grand_total) # random shuffle times in order to find the closest we can get to honoring the percent_train requirement (train and val both need to be non-empty) - for _ in range(random_tries): - train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, verbose=verbose) + for _try in range(random_tries): + seed = split_seed + _try + train_subjects, val_subjects, percent_train_estimate = shuffle_and_cut(subject_counts=subject_counts, grand_total=grand_total, percent_train=percent_train, seed=seed, verbose=verbose) if abs(percent_train_estimate - percent_train) < abs(best_percent_train - percent_train): best_train_subjects = train_subjects best_val_subjects = val_subjects @@ -185,16 +190,16 @@ def shuffle_and_cut(subject_counts, grand_total, percent_train, verbose=False): return train_subject_to_timestamps, val_subject_to_timestamps -def write_splits_file(subject_to_timestamps, percent_train, split_logic, fold, task, splits_fname='splits_final.pkl', verbose=False): +def write_splits_file(subject_to_timestamps, percent_train, split_logic, split_seed, fold, task, splits_fname='splits_final.pkl', 
verbose=False): # double check we are in the right folder to modify the splits file splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', splits_fname) POSTOPP_splits_fpath = os.path.join(os.environ['nnUNet_preprocessed'], f'{task}', 'POSTOPP_BACKUP_' + splits_fname) # now split if split_logic == 'by_subject': - train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + train_subject_to_timestamps, val_subject_to_timestamps = split_by_subject(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) elif split_logic == 'by_subject_time_pair': - train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, verbose=verbose) + train_subject_to_timestamps, val_subject_to_timestamps = split_by_timed_subjects(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, split_seed=split_seed, verbose=verbose) else: raise ValueError(f"Split logic of 'by_subject' and 'by_subject_time_pair' are the only ones supported, whereas a split_logic value of {split_logic} was provided.") @@ -235,6 +240,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path, cuda_device, overwrite_nnunet_datadirs, + split_seed=7777777, plans_path=None, verbose=False): """ @@ -297,10 +303,6 @@ def setup_fl_data(postopp_pardir, three_digit_task_num(str): Should start with '5'. task_name(str) : Any string task name. 
- percent_train(float) : what percent of data is put into the training data split (rest to val) - split_logic(str) : Determines how train/val split is performed - timestamp_selection(str) : Indicates how to determine the timestamp to pick - for each subject ID at the source: 'latest', 'earliest', and 'all' are the only ones supported so far network(str) : Which network is being used for NNUnet network_trainer(str) : Which network trainer class is being used for NNUnet local_plans_identifier(str) : Used in the plans file name for a collaborator that will be performing local training to produce an initial model @@ -309,6 +311,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path(str) : Path to the initial model info (pkl) file cuda_device(str) : Device to perform training ('cpu' or 'cuda') overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. + split_seed (int) : Seed used for the random number generator used within the split logic plans_path(str) : Path to the training plans (pkl) percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since all timestamps associated with the same subject need to land exclusively in either train or val. 
@@ -363,7 +366,8 @@ def setup_fl_data(postopp_pardir, # Now compute our own stratified splits file, keeping all timestampts for a given subject exclusively in either train or val write_splits_file(subject_to_timestamps=subject_to_timestamps, percent_train=percent_train, - split_logic=split_logic, + split_logic=split_logic, + split_seed=split_seed, fold=fold, task=task, verbose=verbose) @@ -414,4 +418,5 @@ def setup_fl_data(postopp_pardir, print(f"plans_path: {col_paths['plans_path']}") print(f"\n### ### ### ### ### ### ###\n") - return col_paths \ No newline at end of file + return col_paths + \ No newline at end of file From d05dfe3ec46299e4f01d4fc71c9453ea560c8a56 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 18:12:48 -0800 Subject: [PATCH 237/242] non default argument had been following default argument --- examples/fl_post/fl/project/nnunet_data_setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index b8e5b18ba..9191a3de6 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -134,7 +134,7 @@ def split_by_subject(subject_to_timestamps, percent_train, split_seed, verbose=F return train_subject_to_timestamps, val_subject_to_timestamps -def split_by_timed_subjects(subject_to_timestamps, percent_train, random_tries=30, split_seed, verbose=False): +def split_by_timed_subjects(subject_to_timestamps, percent_train, split_seed, random_tries=30, verbose=False): """ NOTE: An attempt is made to put percent_train of the subject timestamp combinations into train (as opposed to val) regardless of what that does to the subject ratios. No subject is allowed to have samples in both train and val. 
@@ -419,4 +419,3 @@ def setup_fl_data(postopp_pardir, print(f"\n### ### ### ### ### ### ###\n") return col_paths - \ No newline at end of file From 0d9890eb99790b2d3c7e2af61486181732d30f6f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 18:37:51 -0800 Subject: [PATCH 238/242] testing a training config --- .../fl_post/fl/mlcube/workspace/training_config.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 7c118424c..1e8e05e45 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -5,7 +5,7 @@ aggregator : init_state_path : save/fl_post_two_init.pbuf best_state_path : save/fl_post_two_best.pbuf last_state_path : save/fl_post_two_last.pbuf - rounds_to_train : &rounds_to_train 2 + rounds_to_train : &rounds_to_train 40 admins_endpoints_mapping: testfladmin@example.com: - GetExperimentStatus @@ -19,12 +19,12 @@ aggregator : admin_settable: True min: 10 # 10 seconds max: 86400 # one day - value: 86400 # one day + value: 20 # one day val_cutoff_time: admin_settable: True min: 10 # 10 seconds max: 86400 # one day - value: 86400 # one day + value: 20 # one day train_completion_dampener: # train_completed -> (train_completed)**(train_completion_dampener) NOTE: Value close to zero zero shifts non 0.0 completion rates much closer to 1.0 admin_settable: True min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) @@ -36,7 +36,7 @@ aggregator : admin_settable: True min: 10 # 10 seconds max: 86400 # one day - value: 86400 # one day + value: 20 # one day collaborator : @@ -116,4 +116,4 @@ straggler_handling_policy : template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling settings : 
straggler_cutoff_time : 600 - minimum_reporting : 5 \ No newline at end of file + minimum_reporting : 5 From bbfc71f12baf6e5eaacecc3813ad0fe54097f6f9 Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 18:53:24 -0800 Subject: [PATCH 239/242] putting capability to set seed at top level --- examples/fl_post/fl/project/nnunet_setup.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 81f6787d8..e9b23f814 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -19,7 +19,8 @@ def main(postopp_pardir, three_digit_task_num, task_name, percent_train=0.8, - split_logic='by_subject_time_pair', + split_logic='by_subject_time_pair', + split_seed=7777777, network='3d_fullres', network_trainer='nnUNetTrainerV2', fold='0', @@ -100,6 +101,7 @@ def main(postopp_pardir, percent_train(float) : The percentage of samples to split into the train portion for the fold specified below (NNUnet makes its own folds but we overwrite all with None except the fold indicated below and put in our own split instead determined by a hard coded split logic default) split_logic(str) : Determines how the percent_train is computed. Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring) + split_seed(int) : base rng seed used in split logic network(str) : NNUnet network to be used network_trainer(str) : NNUnet network trainer to be used fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' @@ -132,6 +134,7 @@ def main(postopp_pardir, task_name=task_name, percent_train=percent_train, split_logic=split_logic, + split_seed=split_seed, fold=fold, timestamp_selection=timestamp_selection, network=network, @@ -187,6 +190,11 @@ def main(postopp_pardir, type=str, default='by_subject_time_pair', help="Determines how the percent_train is computed. 
Choices are: 'by_subject' and 'by_subject_time_pair' (see inner function docstring)") + argparser.add_argument( + '--split_seed', + type=int, + default=7777777, + help="base rng seed used in split logic") argparser.add_argument( '--network', type=str, From dc2b9ea5f69e4a6764e5fb6a30abb434aa50e21f Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 21:01:05 -0800 Subject: [PATCH 240/242] docstring change --- examples/fl_post/fl/project/nnunet_data_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 9191a3de6..73b507466 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -311,7 +311,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path(str) : Path to the initial model info (pkl) file cuda_device(str) : Device to perform training ('cpu' or 'cuda') overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. - split_seed (int) : Seed used for the random number generator used within the split logic + split_seed (int) : Base seed for seeds used for the random number generator within the split logic plans_path(str) : Path to the training plans (pkl) percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since all timestamps associated with the same subject need to land exclusively in either train or val. 
From c53210c9857c6966f468a7dac1321cb0a156b6de Mon Sep 17 00:00:00 2001 From: "Edwards, Brandon" Date: Thu, 7 Nov 2024 21:02:24 -0800 Subject: [PATCH 241/242] docstring --- examples/fl_post/fl/project/nnunet_data_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 73b507466..6db3663d6 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -311,7 +311,7 @@ def setup_fl_data(postopp_pardir, init_model_info_path(str) : Path to the initial model info (pkl) file cuda_device(str) : Device to perform training ('cpu' or 'cuda') overwrite_nnunet_datadirs(bool) : Allows for overwriting past instances of NNUnet data directories using the task numbers from first_three_digit_task_num to that plus one less than number of insitutions. - split_seed (int) : Base seed for seeds used for the random number generator within the split logic + split_seed (int) : Base seed for seeds used for the random number generators within the split logic plans_path(str) : Path to the training plans (pkl) percent_train(float) : What percentage of timestamped subjects to attempt dedicate to train versus val. Will be only approximately acheived in general since all timestamps associated with the same subject need to land exclusively in either train or val. From e35e722ffbfd0952a781ecd9c656e85cf491d130 Mon Sep 17 00:00:00 2001 From: hasan7n Date: Fri, 8 Nov 2024 18:21:28 +0100 Subject: [PATCH 242/242] Revert "Revert "Merge pull request #7 from hasan7n/be_enable_partial_epochs"" This reverts commit 11fd6e766d06494276194141803e78f3d1769981. 
--- .../fl/mlcube/workspace/training_config.yaml | 154 ++++++++---- .../fl_post/fl/project/nnunet_data_setup.py | 16 +- .../fl_post/fl/project/nnunet_model_setup.py | 61 +---- examples/fl_post/fl/project/nnunet_setup.py | 5 +- .../fl/project/src/nnunet_dummy_dataloader.py | 2 +- examples/fl_post/fl/project/src/nnunet_v1.py | 138 ++++++----- .../fl_post/fl/project/src/runner_nnunetv1.py | 223 +++++++++++------- .../fl_post/fl/project/src/runner_pt_chkpt.py | 27 ++- 8 files changed, 354 insertions(+), 272 deletions(-) diff --git a/examples/fl_post/fl/mlcube/workspace/training_config.yaml b/examples/fl_post/fl/mlcube/workspace/training_config.yaml index 8a73fe439..0173431b6 100644 --- a/examples/fl_post/fl/mlcube/workspace/training_config.yaml +++ b/examples/fl_post/fl/mlcube/workspace/training_config.yaml @@ -1,69 +1,119 @@ -aggregator: - defaults: plan/defaults/aggregator.yaml - template: openfl.component.Aggregator - settings: - init_state_path: save/fl_post_two_init.pbuf - best_state_path: save/fl_post_two_best.pbuf - last_state_path: save/fl_post_two_last.pbuf - rounds_to_train: 2 +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/fl_post_two_init.pbuf + best_state_path : save/fl_post_two_best.pbuf + last_state_path : save/fl_post_two_last.pbuf + rounds_to_train : &rounds_to_train 20 admins_endpoints_mapping: col1@example.com: - GetExperimentStatus - SetStragglerCuttoffTime + - SetDynamicTaskArg + - GetDynamicTaskArg -collaborator: - defaults: plan/defaults/collaborator.yaml - template: openfl.component.Collaborator - settings: - delta_updates: false - opt_treatment: CONTINUE_LOCAL + dynamictaskargs: &dynamictaskargs + train: + train_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 300 # one day + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day + train_completion_dampener: # 
train_completed -> (train_completed)**(train_completion_dampener) NOTE: Value close to zero zero shifts non 0.0 completion rates much closer to 1.0 + admin_settable: True + min: -1.0 # inverts train_comleted, so this would be a way to have checkpoint_weighting = train_completed * data_size (as opposed to data_size / train_completed) + max: 1.0 # leaves completion rates as is + value: 1.0 -data_loader: - defaults: plan/defaults/data_loader.yaml - template: src.nnunet_dummy_dataloader.NNUNetDummyDataLoader - settings: - p_train: 0.8 + aggregated_model_validation: + val_cutoff_time: + admin_settable: True + min: 10 # 10 seconds + max: 86400 # one day + value: 86400 # one day -# TODO: make checkpoint-only truly generic and create the task runner within src -task_runner: - defaults: plan/defaults/task_runner.yaml - template: src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner - settings: - device: cuda - gpu_num_string: "0" - nnunet_task: Task537_FLPost -network: - defaults: plan/defaults/network.yaml +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : CONTINUE_LOCAL + dynamictaskargs: *dynamictaskargs + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.nnunet_dummy_dataloader.NNUNetDummyDataLoader + settings : + p_train : 0.8 + +# TODO: make checkpoint-only truly generic and create the task runner within src +task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.runner_nnunetv1.PyTorchNNUNetCheckpointTaskRunner + settings : + device : cuda + gpu_num_string : '0' + nnunet_task : Task537_FLPost + actual_max_num_epochs : *rounds_to_train + +network : + defaults : plan/defaults/network.yaml settings: {} -assigner: - defaults: plan/defaults/assigner.yaml - template: openfl.component.assigner.DynamicRandomGroupedAssigner - settings: - task_groups: - - name: train_and_validate - percentage: 1.0 - tasks: - # - 
aggregated_model_validation +assigner : + defaults : plan/defaults/assigner.yaml + template : openfl.component.assigner.DynamicRandomGroupedAssigner + settings : + task_groups : + - name : train_and_validate + percentage : 1.0 + tasks : + - aggregated_model_validation - train - # - locally_tuned_model_validation + - locally_tuned_model_validation -tasks: - defaults: plan/defaults/tasks_torch.yaml +tasks : + defaults : plan/defaults/tasks_torch.yaml + aggregated_model_validation: + function : validate + kwargs : + metrics : + - val_eval + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : global train: - function: train - kwargs: - metrics: + function : train + kwargs : + metrics : - train_loss + epochs : 1 + locally_tuned_model_validation: + function : validate + kwargs : + metrics : - val_eval - epochs: 1 + - val_eval_C1 + - val_eval_C2 + - val_eval_C3 + - val_eval_C4 + apply : local + from_checkpoint: true -compression_pipeline: - defaults: plan/defaults/compression_pipeline.yaml +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml -straggler_handling_policy: - template: openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling - settings: - straggler_cutoff_time: 600 - minimum_reporting: 2 +straggler_handling_policy : + template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + settings : + straggler_cutoff_time : 1200 + minimum_reporting : 2 \ No newline at end of file diff --git a/examples/fl_post/fl/project/nnunet_data_setup.py b/examples/fl_post/fl/project/nnunet_data_setup.py index 3f24c9515..db3894e9d 100644 --- a/examples/fl_post/fl/project/nnunet_data_setup.py +++ b/examples/fl_post/fl/project/nnunet_data_setup.py @@ -246,11 +246,8 @@ def setup_fl_data(postopp_pardir, should be run using a virtual environment that has nnunet version 1 installed. args: - postopp_pardirs(list of str) : Parent directories for postopp data. 
The length of the list should either be - equal to num_insitutions, or one. If the length of the list is one and num_insitutions is not one, - the samples within that single directory will be used to create num_insititutions shards. - If the length of this list is equal to num_insitutions, the shards are defined by the samples within each string path. - Either way, all string paths within this list should piont to folders that have 'data' and 'labels' subdirectories with structure: + postopp_pardir(str) : Parent directory for postopp data. + This directory should have 'data' and 'labels' subdirectories, with structure: ├── data │ ├── AAAC_0 │ │ ├── 2008.03.30 @@ -298,7 +295,7 @@ def setup_fl_data(postopp_pardir, │ └── AAAC_extra_2008.12.10_final_seg.nii.gz └── report.yaml - three_digit_task_num(str): Should start with '5'. If num_institutions == N (see below), all N task numbers starting with this number will be used. + three_digit_task_num(str): Should start with '5'. task_name(str) : Any string task name. 
percent_train(float) : what percent of data is put into the training data split (rest to val) split_logic(str) : Determines how train/val split is performed @@ -336,7 +333,6 @@ def setup_fl_data(postopp_pardir, # Track the subjects and timestamps for each shard subject_to_timestamps = {} - print(f"\n######### CREATING SYMLINKS TO POSTOPP DATA #########\n") for postopp_subject_dir in all_subjects: subject_to_timestamps[postopp_subject_dir] = symlink_one_subject(postopp_subject_dir=postopp_subject_dir, postopp_data_dirpath=postopp_data_dirpath, @@ -356,12 +352,12 @@ def setup_fl_data(postopp_pardir, # Now call the os process to preprocess the data print(f"\n######### OS CALL TO PREPROCESS DATA #########\n") if plans_path: - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl2d", "None", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "-pl3d", "ExperimentPlanner3D_v21_Pretrained", "-pl2d", "None", "-overwrite_plans", f"{plans_path}", "-overwrite_plans_identifier", "POSTOPP", "-no_pp"]) plans_identifier_for_model_writing = shared_plans_identifier else: # this is a preliminary data setup, which will be passed over to the pretrained plan similar to above after we perform training on this plan - subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity"]) + subprocess.run(["nnUNet_plan_and_preprocess", "-t", f"{three_digit_task_num}", "--verify_dataset_integrity", "-pl2d", "None"]) plans_identifier_for_model_writing = local_plans_identifier # Now compute our own stratified splits file, keeping all 
timestampts for a given subject exclusively in either train or val diff --git a/examples/fl_post/fl/project/nnunet_model_setup.py b/examples/fl_post/fl/project/nnunet_model_setup.py index 4ebd1f9e7..a647a2f44 100644 --- a/examples/fl_post/fl/project/nnunet_model_setup.py +++ b/examples/fl_post/fl/project/nnunet_model_setup.py @@ -7,12 +7,13 @@ def train_on_task(task, network, network_trainer, fold, cuda_device, plans_identifier, continue_training=False, current_epoch=0): os.environ['CUDA_VISIBLE_DEVICES']=cuda_device - print(f"###########\nStarting training for task: {task}\n") - train_nnunet(epochs=1, - current_epoch = current_epoch, - network = network, + print(f"###########\nStarting training a single epoch for task: {task}\n") + # Function below is now hard coded for a single epoch of training. + train_nnunet(actual_max_num_epochs=1000, + fl_round=current_epoch, + network=network, task=task, - network_trainer = network_trainer, + network_trainer=network_trainer, fold=fold, continue_training=continue_training, p=plans_identifier) @@ -60,61 +61,13 @@ def delete_2d_data(network, task, plans_identifier): print(f"\n###########\nDeleting 2D data directory at: {data_dir_2d} \n##############\n") shutil.rmtree(data_dir_2d) -""" -def normalize_architecture(reference_plan_path, target_plan_path): - - # Take the plan file from reference_plan_path and use its contents to copy architecture into target_plan_path - # NOTE: Here we perform some checks and protection steps so that our assumptions if not correct will more - likely leed to an exception. 
- - - - assert_same_keys = ['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'num_classes', 'all_classes', 'base_num_features', - 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', - 'transpose_backward', 'preprocessor_name', 'conv_per_stage', 'data_identifier'] - copy_over_keys = ['plans_per_stage'] - nullify_keys = ['original_spacings', 'original_sizes'] - leave_alone_keys = ['list_of_npz_files', 'preprocessed_data_folder', 'dataset_properties'] - - - # check I got all keys here - assert set(copy_over_keys).union(set(assert_same_keys)).union(set(nullify_keys)).union(set(leave_alone_keys)) == set(['num_stages', 'num_modalities', 'modalities', 'normalization_schemes', 'dataset_properties', 'list_of_npz_files', 'original_spacings', 'original_sizes', 'preprocessed_data_folder', 'num_classes', 'all_classes', 'base_num_features', 'use_mask_for_norm', 'keep_only_largest_region', 'min_region_size_per_class', 'min_size_per_class', 'transpose_forward', 'transpose_backward', 'data_identifier', 'plans_per_stage', 'preprocessor_name', 'conv_per_stage']) - - def get_pickle_obj(path): - with open(path, 'rb') as _file: - plan= pkl.load(_file) - return plan - - def write_pickled_obj(obj, path): - with open(path, 'wb') as _file: - pkl.dump(obj, _file) - - reference_plan = get_pickle_obj(path=reference_plan_path) - target_plan = get_pickle_obj(path=target_plan_path) - - for key in assert_same_keys: - if reference_plan[key] != target_plan[key]: - raise ValueError(f"normalize architecture failed since the reference and target plans differed in at least key: {key}") - for key in copy_over_keys: - target_plan[key] = reference_plan[key] - for key in nullify_keys: - target_plan[key] = None - # leave alone keys are left alone :) - - # write back to target plan - write_pickled_obj(obj=target_plan, path=target_plan_path) -""" def trim_data_and_setup_model(task, network, network_trainer, 
plans_identifier, fold, init_model_path, init_model_info_path, plans_path, cuda_device='0'): """ Note that plans_identifier here is designated from fl_setup.py and is an alternative to the default one due to overwriting of the local plans by a globally distributed one """ - # Remove 2D data and 2D data info if appropriate - if network != '2d': - delete_2d_data(network=network, task=task, plans_identifier=plans_identifier) - # get or create architecture info model_folder = get_model_folder(network=network, @@ -141,7 +94,7 @@ def trim_data_and_setup_model(task, network, network_trainer, plans_identifier, shutil.copyfile(src=col_paths['final_model_path'],dst=col_paths['initial_model_path']) shutil.copyfile(src=col_paths['final_model_info_path'],dst=col_paths['initial_model_info_path']) else: - print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\ncol_paths were: {col_paths}\n\n") + print(f"\n######### WRITING MODEL, MODEL INFO, and PLANS #########\n") shutil.copy(src=plans_path,dst=col_paths['plans_path']) shutil.copyfile(src=init_model_path,dst=col_paths['initial_model_path']) shutil.copyfile(src=init_model_info_path,dst=col_paths['initial_model_info_path']) diff --git a/examples/fl_post/fl/project/nnunet_setup.py b/examples/fl_post/fl/project/nnunet_setup.py index 0106b8f95..81f6787d8 100644 --- a/examples/fl_post/fl/project/nnunet_setup.py +++ b/examples/fl_post/fl/project/nnunet_setup.py @@ -28,7 +28,7 @@ def main(postopp_pardir, plans_path=None, local_plans_identifier=local_plans_identifier, shared_plans_identifier=shared_plans_identifier, - overwrite_nnunet_datadirs=False, + overwrite_nnunet_datadirs=True, timestamp_selection='all', cuda_device='0', verbose=False): @@ -105,7 +105,7 @@ def main(postopp_pardir, fold(str) : Fold to train on, can be a sting indicating an int, or can be 'all' local_plans_identifier(str) : Used in the plans file naming for collaborators that will be performing local training to produce a pretrained model. 
shared_plans_identifier(str) : Used in the plans file naming for the shared plan distributed across the federation. - overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories with task numbers from first_three_digit_task_num to that plus one les than number of insitutions. + overwrite_nnunet_datadirs(str) : Allows overwriting NNUnet directories for given task number and name. task_name(str) : Any string task name. timestamp_selection(str) : Indicates how to determine the timestamp to pick. Only 'earliest', 'latest', or 'all' are supported. for each subject ID at the source: 'latest' and 'earliest' are the only ones supported so far @@ -126,6 +126,7 @@ def main(postopp_pardir, # task_folder_info is a zipped lists indexed over tasks (collaborators) # zip(task_nums, tasks, nnunet_dst_pardirs, nnunet_images_train_pardirs, nnunet_labels_train_pardirs) + col_paths = setup_fl_data(postopp_pardir=postopp_pardir, three_digit_task_num=three_digit_task_num, task_name=task_name, diff --git a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py index 1fe83a4f5..68cbbbc40 100644 --- a/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py +++ b/examples/fl_post/fl/project/src/nnunet_dummy_dataloader.py @@ -33,4 +33,4 @@ def get_valid_data_size(self): return self.valid_data_size def get_task_name(self): - return self.task_name + return self.task_name \ No newline at end of file diff --git a/examples/fl_post/fl/project/src/nnunet_v1.py b/examples/fl_post/fl/project/src/nnunet_v1.py index 78869d4c1..2e5df028b 100644 --- a/examples/fl_post/fl/project/src/nnunet_v1.py +++ b/examples/fl_post/fl/project/src/nnunet_v1.py @@ -54,14 +54,17 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def train_nnunet(epochs, - current_epoch, +def train_nnunet(actual_max_num_epochs, + fl_round, + val_epoch=True, + train_epoch=True, + train_cutoff=np.inf, + val_cutoff=np.inf, 
network='3d_fullres', network_trainer='nnUNetTrainerV2', task='Task543_FakePostOpp_More', fold='0', continue_training=True, - validation_only=False, c=False, p=plans_param, use_compressed_data=False, @@ -78,9 +81,13 @@ def train_nnunet(epochs, pretrained_weights=None): """ + actual_max_num_epochs (int): Provides the number of epochs intended to be trained over the course of the whole federation (for lr scheduling) + (this needs to be held constant outside of individual calls to this function so that the lr is consistetly scheduled) + fl_round (int): Federated round, equal to the epoch used for the model (in lr scheduling) + val_epoch (bool) : Will validation be performed + train_epoch (bool) : Will training run (rather than val only) task (int): can be task name or task id fold: "0, 1, ..., 5 or 'all'" - validation_only: use this if you want to only run the validation c: use this if you want to continue a training p: plans identifier. Only change this if you created a custom experiment planner use_compressed_data: "If you set use_compressed_data, the training cases will not be decompressed. 
Reading compressed data " @@ -132,7 +139,6 @@ def __init__(self, **kwargs): fold = args.fold network = args.network network_trainer = args.network_trainer - validation_only = args.validation_only plans_identifier = args.p find_lr = args.find_lr disable_postprocessing_on_folds = args.disable_postprocessing_on_folds @@ -198,6 +204,7 @@ def __init__(self, **kwargs): trainer = trainer_class( plans_file, fold, + actual_max_num_epochs=actual_max_num_epochs, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, @@ -206,6 +213,30 @@ def __init__(self, **kwargs): deterministic=deterministic, fp16=run_mixed_precision, ) + + + trainer.initialize(True) + + if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": + trainer.save_checkpoint( + join(trainer.output_folder, "model_final_checkpoint.model") + ) + print("Preparation round: Model-averaging") + return + + if find_lr: + trainer.find_lr(num_iters=self.actual_max_num_epochs) + else: + if args.continue_training: + # -c was set, continue a previous training and ignore pretrained weights + trainer.load_latest_checkpoint() + elif (not args.continue_training) and (args.pretrained_weights is not None): + # we start a new training. 
If pretrained_weights are set, use them + load_pretrained_weights(trainer.network, args.pretrained_weights) + else: + # new training without pretraine weights, do nothing + pass + # we want latest checkoint only (not best or any intermediate) trainer.save_final_checkpoint = ( True # whether or not to save the final checkpoint @@ -221,61 +252,44 @@ def __init__(self, **kwargs): trainer.save_latest_only = ( True # if false it will not store/overwrite _latest but separate files each ) - trainer.max_num_epochs = current_epoch + epochs - trainer.epoch = current_epoch - - # TODO: call validation separately - trainer.initialize(not validation_only) - - if os.getenv("PREP_INCREMENT_STEP", None) == "from_dataset_properties": - trainer.save_checkpoint( - join(trainer.output_folder, "model_final_checkpoint.model") - ) - print("Preparation round: Model-averaging") - return - - if find_lr: - trainer.find_lr() - else: - if not validation_only: - if args.continue_training: - # -c was set, continue a previous training and ignore pretrained weights - trainer.load_latest_checkpoint() - elif (not args.continue_training) and (args.pretrained_weights is not None): - # we start a new training. 
If pretrained_weights are set, use them - load_pretrained_weights(trainer.network, args.pretrained_weights) - else: - # new training without pretraine weights, do nothing - pass - - trainer.run_training() - else: - # if valbest: - # trainer.load_best_checkpoint(train=False) - # else: - # trainer.load_final_checkpoint(train=False) - trainer.load_latest_checkpoint() - trainer.network.eval() - - # if fold == "all": - # print("--> fold == 'all'") - # print("--> DONE") - # else: - # # predict validation - # trainer.validate( - # save_softmax=args.npz, - # validation_folder_name=val_folder, - # run_postprocessing_on_folds=not disable_postprocessing_on_folds, - # overwrite=args.val_disable_overwrite, - # ) - - # if network == "3d_lowres" and not args.disable_next_stage_pred: - # print("predicting segmentations for the next stage of the cascade") - # predict_next_stage( - # trainer, - # join( - # dataset_directory, - # trainer.plans["data_identifier"] + "_stage%d" % 1, - # ), - # ) + trainer.max_num_epochs = fl_round + 1 + trainer.epoch = fl_round + + # infer total data size and batch size in order to get how many batches to apply so that over many epochs, each data + # point is expected to be seen epochs number of times + + num_val_batches_per_epoch = int(np.ceil(len(trainer.dataset_val)/trainer.batch_size)) + num_train_batches_per_epoch = int(np.ceil(len(trainer.dataset_tr)/trainer.batch_size)) + + # the nnunet trainer attributes have a different naming convention than I am using + trainer.num_batches_per_epoch = num_train_batches_per_epoch + trainer.num_val_batches_per_epoch = num_val_batches_per_epoch + + batches_applied_train, \ + batches_applied_val, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = trainer.run_training(train_cutoff=train_cutoff, + val_cutoff=val_cutoff, + val_epoch=val_epoch, + train_epoch=train_epoch) + + 
train_completed = batches_applied_train / float(num_train_batches_per_epoch) + val_completed = batches_applied_val / float(num_val_batches_per_epoch) + + return train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 + + diff --git a/examples/fl_post/fl/project/src/runner_nnunetv1.py b/examples/fl_post/fl/project/src/runner_nnunetv1.py index 26f184522..3191d9a57 100644 --- a/examples/fl_post/fl/project/src/runner_nnunetv1.py +++ b/examples/fl_post/fl/project/src/runner_nnunetv1.py @@ -2,11 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports -# TODO: ask Micah if this has to be changed (most probably no) import os import subprocess @@ -36,12 +35,15 @@ class PyTorchNNUNetCheckpointTaskRunner(PyTorchCheckpointTaskRunner): def __init__(self, nnunet_task=None, config_path=None, + actual_max_num_epochs=1000, **kwargs): """Initialize. Args: - config_path(str) : Path to the configuration file used by the training and validation script. - kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). + nnunet_task (str) : Task string used to identify the data and model folders + config_path(str) : Path to the configuration file used by the training and validation script. + actual_max_num_epochs (int) : Number of epochs for which this collaborator's model will be trained, should match the total rounds of federation in which this runner is participating + kwargs : Additional key work arguments (will be passed to rebuild_model, initialize_tensor_key_functions, TODO: ). 
TODO: """ @@ -72,6 +74,14 @@ def __init__(self, ) self.config_path = config_path + self.actual_max_num_epochs=actual_max_num_epochs + + # self.task_completed is a dictionary of task to amount completed as a float in [0,1] + # Values will be dynamically updated + # TODO: Tasks are hard coded for now + self.task_completed = {'aggregated_model_validation': 1.0, + 'train': 1.0, + 'locally_tuned_model_validation': 1.0} def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): @@ -108,11 +118,10 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): # get device for correct placement of tensors device = self.device - checkpoint_dict = self.load_checkpoint(map_location=device) + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load, map_location=device) epoch = checkpoint_dict['epoch'] new_state = {} # grabbing keys from the checkpoint state dict, poping from the tensor_dict - # Brandon DEBUGGING seen_keys = [] for k in checkpoint_dict['state_dict']: if k not in seen_keys: @@ -134,91 +143,113 @@ def write_tensors_into_checkpoint(self, tensor_dict, with_opt_vars): return epoch - def train(self, col_name, round_num, input_tensor_dict, epochs, **kwargs): + def train(self, col_name, round_num, input_tensor_dict, epochs, val_cutoff_time=np.inf, train_cutoff_time=np.inf, train_completion_dampener=0.0, **kwargs): # TODO: Figure out the right name to use for this method and the default assigner """Perform training for a specified number of epochs.""" self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) # 1. Insert tensor_dict info into checkpoint - current_epoch = self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) - # 2. 
Train function existing externally - # Some todo inside function below - # TODO: test for off-by-one error - # TODO: we need to disable validation if possible, and separately call validation - - # FIXME: we need to understand how to use round_num instead of current_epoch - # this will matter in straggler handling cases - # TODO: Should we put this in a separate process? - train_nnunet(epochs=epochs, current_epoch=current_epoch, task=self.data_loader.get_task_name()) - - # 3. Load metrics from checkpoint - (all_tr_losses, all_val_losses, all_val_losses_tr_mode, all_val_eval_metrics) = self.load_checkpoint()['plot_stuff'] - # these metrics are appended to the checkopint each epoch, so we select the most recent epoch - metrics = {'train_loss': all_tr_losses[-1], - 'val_eval': all_val_eval_metrics[-1]} - - return self.convert_results_to_tensorkeys(col_name, round_num, metrics) - - - - def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Training for round:{round_num}") + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=train_cutoff_time, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=True) + + # dampen the train_completion """ - Run the trained model on validation data; report results. - - Parameters - ---------- - input_tensor_dict : either the last aggregated or locally trained model - - Returns - ------- - output_tensor_dict : {TensorKey: nparray} (these correspond to acc, - precision, f1_score, etc.) 
+ values in range: (0, 1] with values near 0.0 making all train_completion rates shift nearer to 1.0, thus making the + trained model update weighting during aggregation stay closer to the plain data size weighting + specifically, update_weight = train_data_size / train_completed**train_completion_dampener """ + train_completed = train_completed**train_completion_dampener - raise NotImplementedError() - - """ - TBD - for now commenting out + # update amount of task completed + self.task_completed['train'] = train_completed + self.task_completed['locally_tuned_model_validation'] = val_completed - self.rebuild_model(round_num, input_tensor_dict, validation=True) + # 3. Prepare metrics + metrics = {'train_loss': this_ave_train_loss} - # 1. Save model in native format - self.save_native(self.mlcube_model_in_path) + global_tensor_dict, local_tensor_dict = self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=True) - # 2. Call MLCube validate task - platform_yaml = os.path.join(self.mlcube_dir, 'platforms', '{}.yaml'.format(self.mlcube_runner_type)) - task_yaml = os.path.join(self.mlcube_dir, 'run', 'evaluate.yaml') - proc = subprocess.run(["mlcube_docker", - "run", - "--mlcube={}".format(self.mlcube_dir), - "--platform={}".format(platform_yaml), - "--task={}".format(task_yaml)]) - - # 3. Load any metrics - metrics = self.load_metrics(os.path.join(self.mlcube_dir, 'workspace', 'metrics', 'evaluate_metrics.json')) - - # set the validation data size - sample_count = int(metrics.pop(self.evaluation_sample_count_key)) - self.data_loader.set_valid_data_size(sample_count) + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + return global_tensor_dict, local_tensor_dict + - # 4. 
Convert to tensorkeys - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' + def validate(self, col_name, round_num, input_tensor_dict, val_cutoff_time=np.inf, from_checkpoint=False, **kwargs): + # TODO: Figure out the right name to use for this method and the default assigner + """Perform validation.""" + + if not from_checkpoint: + self.rebuild_model(input_tensor_dict=input_tensor_dict, **kwargs) + # 1. Insert tensor_dict info into checkpoint + self.set_tensor_dict(tensor_dict=input_tensor_dict, with_opt_vars=False) + self.logger.info(f"Validating for round:{round_num}") + # 2. Train/val function existing externally + # Some todo inside function below + train_completed, \ + val_completed, \ + this_ave_train_loss, \ + this_ave_val_loss, \ + this_val_eval_metrics, \ + this_val_eval_metrics_C1, \ + this_val_eval_metrics_C2, \ + this_val_eval_metrics_C3, \ + this_val_eval_metrics_C4 = train_nnunet(actual_max_num_epochs=self.actual_max_num_epochs, + fl_round=round_num, + train_cutoff=0, + val_cutoff = val_cutoff_time, + task=self.data_loader.get_task_name(), + val_epoch=True, + train_epoch=False) + # double check + if train_completed != 0.0: + raise ValueError(f"Tried to validate only, but got a non-zero amount ({train_completed}) of training done.") + + # update amount of task completed + self.task_completed['aggregated_model_validation'] = val_completed + + self.logger.info(f"Completed train/val with {int(train_completed*100)}% of the train work and {int(val_completed*100)}% of the val work. Exact rates are: {train_completed} and {val_completed}") + + + # 3. 
Prepare metrics + metrics = {'val_eval': this_val_eval_metrics, + 'val_eval_C1': this_val_eval_metrics_C1, + 'val_eval_C2': this_val_eval_metrics_C2, + 'val_eval_C3': this_val_eval_metrics_C3, + 'val_eval_C4': this_val_eval_metrics_C4} else: - suffix += '_agg' - tags = ('metric', suffix) - output_tensor_dict = { - TensorKey( - metric_name, origin, round_num, True, tags - ): np.array(metrics[metric_name]) - for metric_name in metrics - } - - return output_tensor_dict, {} - - """ + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) + + all_tr_losses, \ + all_val_losses, \ + all_val_losses_tr_mode, \ + all_val_eval_metrics, \ + all_val_eval_metrics_C1, \ + all_val_eval_metrics_C2, \ + all_val_eval_metrics_C3, \ + all_val_eval_metrics_C4 = checkpoint_dict['plot_stuff'] + # these metrics are appended to the checkpoint each call to train, so it is critical that we are grabbing this right after + metrics = {'val_eval': all_val_eval_metrics[-1], + 'val_eval_C1': all_val_eval_metrics_C1[-1], + 'val_eval_C2': all_val_eval_metrics_C2[-1], + 'val_eval_C3': all_val_eval_metrics_C3[-1], + 'val_eval_C4': all_val_eval_metrics_C4[-1]} + + return self.convert_results_to_tensorkeys(col_name, round_num, metrics, insert_model=False) def load_metrics(self, filepath): @@ -230,4 +261,38 @@ def load_metrics(self, filepath): with open(filepath) as json_file: metrics = json.load(json_file) return metrics - """ \ No newline at end of file + """ + + + def get_train_data_size(self, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. 
+ + Returns: + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema + """ + if not task_name: + return self.data_loader.get_train_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_train_data_size())) + + + def get_valid_data_size(self, task_name=None): + """Get the number of training examples. + + It will be used for weighted averaging in aggregation. + This overrides the parent class method, + allowing dynamic weighting by storing recent appropriate weights in class attributes. + + Returns: + int: The number of training examples, weighted by how much of the task got completed, then cast to int to satisy proto schema + """ + if not task_name: + return self.data_loader.get_valid_data_size() + else: + # self.task_completed is a dictionary of task_name to amount completed as a float in [0,1] + return int(np.ceil(self.task_completed[task_name]**(-1) * self.data_loader.get_valid_data_size())) diff --git a/examples/fl_post/fl/project/src/runner_pt_chkpt.py b/examples/fl_post/fl/project/src/runner_pt_chkpt.py index 6ab7851b9..a7fbd2056 100644 --- a/examples/fl_post/fl/project/src/runner_pt_chkpt.py +++ b/examples/fl_post/fl/project/src/runner_pt_chkpt.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Contributors: Micah Sheller, Patrick Foley, Brandon Edwards - DELETEME? +Contributors: Micah Sheller, Patrick Foley, Brandon Edwards """ # TODO: Clean up imports @@ -82,12 +82,10 @@ def __init__(self, self.replace_checkpoint(self.checkpoint_path_initial) - def load_checkpoint(self, checkpoint_path=None, map_location=None): + def load_checkpoint(self, checkpoint_path, map_location=None): """ Function used to load checkpoint from disk. 
""" - if not checkpoint_path: - checkpoint_path = self.checkpoint_path_load checkpoint_dict = torch.load(checkpoint_path, map_location=map_location) return checkpoint_dict @@ -124,7 +122,7 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): return self.required_tensorkeys_for_function[func_name] def reset_opt_vars(self): - current_checkpoint_dict = self.load_checkpoint() + current_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) initial_checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_initial) derived_opt_state_dict = self._get_optimizer_state(checkpoint_dict=initial_checkpoint_dict) self._set_optimizer_state(derived_opt_state_dict=derived_opt_state_dict, @@ -172,7 +170,7 @@ def read_tensors_from_checkpoint(self, with_opt_vars): dict: Tensor dictionary {**dict, **optimizer_dict} """ - checkpoint_dict = self.load_checkpoint() + checkpoint_dict = self.load_checkpoint(checkpoint_path=self.checkpoint_path_load) state = to_cpu_numpy(checkpoint_dict['state_dict']) if with_opt_vars: opt_state = self._get_optimizer_state(checkpoint_dict=checkpoint_dict) @@ -255,7 +253,9 @@ def _read_opt_state_from_checkpoint(self, checkpoint_dict): return derived_opt_state_dict - def convert_results_to_tensorkeys(self, col_name, round_num, metrics): + def convert_results_to_tensorkeys(self, col_name, round_num, metrics, insert_model): + # insert_model determined whether or not to include the model in the return dictionaries + # 5. 
Convert to tensorkeys # output metric tensors (scalar) @@ -268,11 +268,14 @@ def convert_results_to_tensorkeys(self, col_name, round_num, metrics): metrics[metric_name] ) for metric_name in metrics} - # output model tensors (Doesn't include TensorKey) - output_model_dict = self.get_tensor_dict(with_opt_vars=True) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, - tensor_dict=output_model_dict, - **self.tensor_dict_split_fn_kwargs) + if insert_model: + # output model tensors (Doesn't include TensorKey) + output_model_dict = self.get_tensor_dict(with_opt_vars=True) + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(logger=self.logger, + tensor_dict=output_model_dict, + **self.tensor_dict_split_fn_kwargs) + else: + global_model_dict, local_model_dict = {}, {} # create global tensorkeys global_tensorkey_model_dict = {