from pydantic import BaseModel, Field
from typing import Optional, Dict, Set

from typing_extensions import Union


class DefaultParameters(BaseModel):
    """Baseline GaNDLF configuration values.

    Each field carries the value applied when the user's YAML config omits
    the corresponding key; subclasses layer the required, user-supplied
    parameters on top of these defaults.
    """

    # ---- verbosity / output flags ------------------------------------
    weighted_loss: bool = Field(
        default=False, description="Whether weighted loss is to be used or not."
    )
    verbose: bool = Field(default=False, description="General application verbosity.")
    q_verbose: bool = Field(default=False, description="Queue construction verbosity.")
    medcam_enabled: bool = Field(
        default=False, description="Enable interpretability via medcam."
    )
    save_training: bool = Field(
        default=False, description="Save outputs during training."
    )
    save_output: bool = Field(
        default=False, description="Save outputs during validation/testing."
    )

    # ---- memory pinning ----------------------------------------------
    in_memory: bool = Field(default=False, description="Pin data to CPU memory.")
    pin_memory_dataloader: bool = Field(
        default=False, description="Pin data to GPU memory."
    )

    # ---- queue / dataloader ------------------------------------------
    scaling_factor: int = Field(
        default=1, description="Scaling factor for regression problems."
    )
    q_max_length: int = Field(default=100, description="The max length of the queue.")
    q_samples_per_volume: int = Field(
        default=10, description="Number of samples per volume."
    )
    q_num_workers: int = Field(
        default=4, description="Number of worker threads to use."
    )

    # ---- training loop -----------------------------------------------
    num_epochs: int = Field(default=100, description="Total number of epochs to train.")
    patience: int = Field(
        default=100, description="Number of epochs to wait for performance improvement."
    )
    batch_size: int = Field(default=1, description="Default batch size for training.")
    learning_rate: float = Field(default=0.001, description="Default learning rate.")
    clip_grad: Optional[float] = Field(
        default=None, description="Gradient clipping value."
    )

    # ---- diagnostics / misc ------------------------------------------
    track_memory_usage: bool = Field(
        default=False, description="Enable memory usage tracking."
    )
    memory_save_mode: bool = Field(
        default=False,
        description="Enable memory-saving mode. If enabled, resize/resample will save files to disk.",
    )
    print_rgb_label_warning: bool = Field(
        default=True, description="Print a warning for RGB labels."
    )
    data_postprocessing: Union[dict, set] = Field(
        default={}, description="Default data postprocessing configuration."
    )
    grid_aggregator_overlap: str = Field(
        default="crop", description="Default grid aggregator overlap strategy."
    )
    determinism: bool = Field(
        default=False, description="Enable deterministic computation."
    )
    previous_parameters: Optional[Dict] = Field(
        default=None,
        description="Previous parameters to be used for resuming training and performing sanity checks.",
    )
from GANDLF.models import global_models_dict

# Admissible architecture names come straight from the model registry,
# so config validation stays in sync with the implemented models.
ARCHITECTURE_OPTIONS = Literal[tuple(global_models_dict.keys())]
NORM_TYPE_OPTIONS = Literal["batch", "instance", "none"]


# You can define new parameters for model here. Please read the pydantic documentation.
class Model(BaseModel):
    """Pydantic schema for the ``model`` section of a GaNDLF config.

    Extra keys are tolerated (``extra="allow"``) so user configs may carry
    architecture-specific options that are not modelled here.
    """

    model_config = ConfigDict(extra="allow")

    dimension: Optional[int] = Field(description="Dimension.")
    architecture: Union[ARCHITECTURE_OPTIONS, dict] = Field(description="Architecture.")
    final_layer: str = Field(description="Final layer.")
    norm_type: Optional[NORM_TYPE_OPTIONS] = Field(
        description="Normalization type.", default="batch"
    )  # TODO: check it again
    base_filters: Optional[int] = Field(
        description="Base filters.", default=None, validate_default=True
    )  # None means "use the library default of 32"; resolved in the validator below
    class_list: Union[list, str] = Field(default=[], description="Class list.")
    num_channels: Optional[int] = Field(
        description="Number of channels.",
        validation_alias=AliasChoices(
            "num_channels", "n_channels", "channels", "model_channels"
        ),
        default=3,
    )  # TODO: check it
    type: Optional[str] = Field(description="Type of model.", default="torch")
    data_type: str = Field(description="Data type.", default="FP32")
    save_at_every_epoch: bool = Field(default=False, description="Save at every epoch.")
    amp: bool = Field(default=False, description="Amplifier.")
    ignore_label_validation: Union[int, None] = Field(
        default=None, description="Ignore label validation."
    )  # TODO: To check it
    print_summary: bool = Field(default=True, description="Print summary.")

    # FIX: renamed from ``model_validate`` — that name collides with
    # pydantic's own ``BaseModel.model_validate`` classmethod (the whole
    # ``model_`` prefix is reserved by pydantic v2), which would break
    # ``Model.model_validate(...)`` for callers.
    @model_validator(mode="after")
    def validate_model(self) -> Self:
        """Post-init normalization and sanity warnings for the model section."""
        # TODO: Change the print calls to logging.warning
        # init and validate the class_list parameter
        self.class_list = validate_class_list(self.class_list)
        # init and validate the norm type against the chosen architecture
        self.norm_type = validate_norm_type(self.norm_type, self.architecture)

        if self.amp is False:
            print("NOT using Mixed Precision Training")

        if self.save_at_every_epoch:
            print(
                "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
            )  # TODO: It is better to use logging.warning

        if self.base_filters is None:
            self.base_filters = 32
            print("Using default 'base_filters' in 'model': ", self.base_filters)

        return self
class NestedTraining(BaseModel):
    """Schema for the ``nested_training`` section (k-fold cross-validation)."""

    stratified: bool = Field(
        default=False,
        description="this will perform stratified k-fold cross-validation but only with offline data splitting",
    )
    testing: int = Field(
        default=-5,
        description="this controls the number of testing data folds for final model evaluation; [NOT recommended] to disable this, use '1'",
    )
    validation: int = Field(
        default=-5,
        description="this controls the number of validation data folds to be used for model *selection* during training (not used for back-propagation)",
    )
    # legacy alias: older configs used "proportional" for stratified splitting
    proportional: Optional[bool] = Field(default=None)

    @model_validator(mode="after")
    def validate_nested_training(self) -> Self:
        """Mirror the legacy ``proportional`` flag onto ``stratified``."""
        if self.proportional is not None:
            self.stratified = self.proportional
        return self


# --- optimizer_parameters.py ---------------------------------------------
from GANDLF.optimizers import global_optimizer_dict

# admissible optimizer names, taken from the optimizer registry
OPTIMIZER_OPTIONS = Literal[tuple(global_optimizer_dict.keys())]


class Optimizer(BaseModel):
    """Schema for the ``optimizer`` section."""

    type: OPTIMIZER_OPTIONS = Field(description="Type of optimizer to use")


# --- parameters.py --------------------------------------------------------
class ParametersConfiguration(BaseModel):
    # extra keys are tolerated so legacy/experimental options survive parsing
    model_config = ConfigDict(extra="allow")


class Parameters(ParametersConfiguration, UserDefinedParameters):
    """Full, validated GaNDLF configuration (user values layered on defaults)."""


# --- patch_sampler.py ------------------------------------------------------
class PatchSampler(BaseModel):
    """Schema for the ``patch_sampler`` section."""

    type: str = Field(default="uniform")
    enable_padding: bool = Field(default=False)
    padding_mode: str = Field(default="symmetric")
    biased_sampling: bool = Field(default=False)
# --- scheduler_parameters.py ----------------------------------------------
from GANDLF.schedulers import global_schedulers_dict

# admissible scheduler names, taken from the scheduler registry
TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]


class Scheduler(BaseModel):
    """Schema for the ``scheduler`` section; extra keys are tolerated."""

    model_config = ConfigDict(extra="allow")

    type: TYPE_OPTIONS = Field(
        description="triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range uses CyclicLR"
    )
    # min_lr: 0.00001, #TODO: this should be defined ??
    # max_lr: 1, #TODO: this should be defined ??
    # FIX: annotated Optional — the default is None (a bare ``float``
    # annotation does not admit None); a None step_size is later derived
    # from the learning rate in validate_schedular().
    step_size: Optional[float] = Field(description="step_size", default=None)


# --- user_defined_parameters.py --------------------------------------------
from typing import Union
from importlib.metadata import version

from GANDLF.utils import version_check
from GANDLF.Configuration.Parameters.default_parameters import DefaultParameters
from GANDLF.Configuration.Parameters.model_parameters import Model
from GANDLF.Configuration.Parameters.nested_training_parameters import NestedTraining
from GANDLF.Configuration.Parameters.optimizer_parameters import Optimizer
from GANDLF.Configuration.Parameters.patch_sampler import PatchSampler
# explicit imports instead of the original star-import keep the module's
# dependency surface auditable
from GANDLF.Configuration.Parameters.validators import (
    validate_data_augmentation,
    validate_data_postprocessing_after_reverse_one_hot_encoding,
    validate_data_preprocessing,
    validate_differential_privacy,
    validate_loss_function,
    validate_metrics,
    validate_optimizer,
    validate_parallel_compute_command,
    validate_patch_sampler,
    validate_patch_size,
    validate_schedular,
)


class Version(BaseModel):  # TODO: Maybe should move to another folder
    """Minimum/maximum GaNDLF versions a config is compatible with."""

    minimum: str
    maximum: str

    @model_validator(mode="after")
    def validate_version(self) -> Self:
        # NOTE(review): version_check appears to raise on incompatibility;
        # if it ever returned False instead, this validator would return
        # None — confirm version_check's contract.
        if version_check(self.model_dump(), version_to_check=version("GANDLF")):
            return self


class InferenceMechanism(BaseModel):
    """Schema for the ``inference_mechanism`` section."""

    grid_aggregator_overlap: Literal["crop", "average"] = Field(default="crop")
    patch_overlap: int = Field(default=0)


class UserDefinedParameters(DefaultParameters):
    """User-supplied (required) configuration keys, layered over the defaults."""

    version: Version = Field(
        default=Version(minimum=version("GANDLF"), maximum=version("GANDLF")),
        # FIX: description was a copy-paste of the weighted_loss text
        description="GANDLF version range (minimum/maximum) this configuration is compatible with.",
    )
    patch_size: Union[list[Union[int, float]], int, float] = Field(
        description="Patch size."
    )
    model: Model = Field(..., description="The model to use. ")
    modality: Literal["rad", "histo", "path"] = Field(description="Modality.")
    loss_function: Annotated[
        Union[dict, str],
        Field(description="Loss function."),
        AfterValidator(validate_loss_function),
    ]
    metrics: Annotated[
        Union[dict, list[Union[str, dict, set]]],
        Field(description="Metrics."),
        AfterValidator(validate_metrics),
    ]
    nested_training: NestedTraining = Field(description="Nested training.")
    parallel_compute_command: str = Field(
        default="", description="Parallel compute command."
    )
    scheduler: Union[str, Scheduler] = Field(
        description="Scheduler.", default=Scheduler(type="triangle_modified")
    )
    optimizer: Union[str, Optimizer] = Field(
        description="Optimizer.", default=Optimizer(type="adam")
    )  # TODO: Check it again for (opt)
    patch_sampler: Union[str, PatchSampler] = Field(
        description="Patch sampler.", default=PatchSampler()
    )
    inference_mechanism: InferenceMechanism = Field(
        description="Inference mechanism.", default=InferenceMechanism()
    )
    data_postprocessing_after_reverse_one_hot_encoding: dict = Field(
        description="data_postprocessing_after_reverse_one_hot_encoding.", default={}
    )
    differential_privacy: Any = Field(description="Differential privacy.", default=None)
    # TODO: It should be defined with a better way (using a BaseModel class)
    data_preprocessing: Annotated[
        dict,
        Field(description="Data preprocessing."),
        AfterValidator(validate_data_preprocessing),
    ] = {}
    # TODO: It should be defined with a better way (using a BaseModel class)
    data_augmentation: Annotated[dict, Field(description="Data augmentation.")] = {}

    # FIX: renamed from ``validate`` — pydantic's BaseModel already defines
    # a (deprecated) ``validate`` classmethod; overriding it shadows that
    # API and triggers warnings. The validator itself is unchanged.
    @model_validator(mode="after")
    def validate_parameters(self) -> Self:
        """Cross-field normalization run after individual field validation."""
        # validate the patch_size (may also infer model.dimension)
        self.patch_size, self.model.dimension = validate_patch_size(
            self.patch_size, self.model.dimension
        )
        # validate the parallel_compute_command
        self.parallel_compute_command = validate_parallel_compute_command(
            self.parallel_compute_command
        )
        # validate scheduler
        self.scheduler = validate_schedular(self.scheduler, self.learning_rate)
        # validate optimizer
        self.optimizer = validate_optimizer(self.optimizer)
        # validate patch_sampler
        self.patch_sampler = validate_patch_sampler(self.patch_sampler)
        # validate data_augmentation
        self.data_augmentation = validate_data_augmentation(
            self.data_augmentation, self.patch_size
        )
        # validate data_postprocessing_after_reverse_one_hot_encoding
        (
            self.data_postprocessing_after_reverse_one_hot_encoding,
            self.data_postprocessing,
        ) = validate_data_postprocessing_after_reverse_one_hot_encoding(
            self.data_postprocessing_after_reverse_one_hot_encoding,
            self.data_postprocessing,
        )
        # validate differential_privacy
        self.differential_privacy = validate_differential_privacy(
            self.differential_privacy, self.batch_size
        )

        return self
def validate_loss_function(value) -> dict:
    """Normalize the loss-function config entry.

    Dict inputs: an ``mse`` entry is guaranteed a ``reduction`` key
    (default ``"mean"``); any other key collapses the whole entry to that
    key's plain-string name.  String inputs: ``"focal"`` and ``"mse"`` are
    expanded to their parameterized dict forms; anything else passes
    through unchanged.
    """
    if isinstance(value, dict):
        for loss_name in value:
            if loss_name == "mse":
                params = value[loss_name]
                # guarantee a reduction mode for MSE
                if params is None or "reduction" not in params:
                    value[loss_name] = {"reduction": "mean"}
            else:
                # use simple string for other functions - can be extended
                # with parameters, if needed
                value = loss_name
    elif value == "focal":
        value = {"focal": {"gamma": 2.0, "size_average": True}}
    elif value == "mse":
        value = {"mse": {"reduction": "mean"}}

    return value
def validate_metrics(value) -> dict:
    """Normalize the ``metrics`` config entry into a dict of metric -> options.

    Accepts either a dict (normalized in place) or a list whose items are
    metric names or single-key dicts; classification-style metrics are
    guaranteed an options dict with sensible defaults.
    """
    if not isinstance(value, dict):
        temp_dict = {}
    else:
        # NOTE: aliases `value`, so normalization mutates the caller's dict
        temp_dict = value

    # initialize metrics dict
    for metric in value:
        # some list items can be single-key dicts; compare on their key
        comparison_string = metric
        if isinstance(metric, dict):
            comparison_string = list(metric.keys())[0]
        # these metrics always need to be dicts
        if comparison_string in [
            "accuracy",
            "f1",
            "precision",
            "recall",
            "specificity",
            "iou",
        ]:
            if not isinstance(metric, dict):
                # NOTE(review): when `value` is a dict this resets any
                # user-supplied options for these metrics to {} before the
                # defaults are filled — confirm pre-configured dict entries
                # should not be preserved
                temp_dict[metric] = {}
            else:
                temp_dict[comparison_string] = metric
        elif not isinstance(metric, dict):
            temp_dict[metric] = None

        # special case for accuracy, precision, recall, and specificity;
        # which could be dicts — fill torchmetrics-style defaults
        if any(
            _ in comparison_string
            for _ in ["precision", "recall", "specificity", "accuracy", "f1"]
        ):
            if comparison_string != "classification_accuracy":
                temp_dict[comparison_string] = initialize_key(
                    temp_dict[comparison_string], "average", "weighted"
                )
                temp_dict[comparison_string] = initialize_key(
                    temp_dict[comparison_string], "multi_class", True
                )
                temp_dict[comparison_string] = initialize_key(
                    temp_dict[comparison_string], "mdmc_average", "samplewise"
                )
                temp_dict[comparison_string] = initialize_key(
                    temp_dict[comparison_string], "threshold", 0.5
                )
                if comparison_string == "accuracy":
                    temp_dict[comparison_string] = initialize_key(
                        temp_dict[comparison_string], "subset_accuracy", False
                    )
        elif "iou" in comparison_string:
            temp_dict["iou"] = initialize_key(
                temp_dict["iou"], "reduction", "elementwise_mean"
            )
            temp_dict["iou"] = initialize_key(temp_dict["iou"], "threshold", 0.5)
        elif comparison_string in surface_distance_ids:
            temp_dict[comparison_string] = initialize_key(
                temp_dict[comparison_string], "connectivity", 1
            )
            temp_dict[comparison_string] = initialize_key(
                temp_dict[comparison_string], "threshold", None
            )

    return temp_dict


def validate_class_list(value):
    """Normalize ``model.class_list``.

    Strings containing ``||``/``&&`` describe combined labels and are split
    into raw label expressions; any other string is evaluated into a Python
    object (e.g. ``"[0, 1, 2]"`` -> ``[0, 1, 2]``).
    """
    if isinstance(value, str):
        if ("||" in value) or ("&&" in value):
            # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction
            print(
                "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly"
            )
            # we don't need the brackets
            value = value.replace("[", "").replace("]", "").split(",")
        else:
            try:
                # SECURITY NOTE: eval() on config text executes arbitrary
                # code; only trusted configs should reach this point
                # (ast.literal_eval would be safer if configs are literals).
                value = eval(value)
                return value
            except Exception as e:
                # FIX: raise explicitly instead of `assert False` so the
                # error also fires under `python -O` (asserts are stripped
                # there); AssertionError is kept for backward compatibility.
                raise AssertionError(
                    f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}"
                )
    return value


def validate_patch_size(patch_size, dimension) -> list:
    """Normalize patch_size to a per-axis list and infer the dimension.

    A scalar is broadcast across all axes; a 2-element list gains a
    trailing 1 so torchio treats it as a flat 3D patch.

    Returns:
        list: ``[patch_size, dimension]``.
    """
    if isinstance(patch_size, (int, float)):
        patch_size = [patch_size]
    # broadcast a single value across all known axes
    if len(patch_size) == 1 and dimension is not None:
        patch_size = [patch_size[0]] * dimension
    if len(patch_size) == 2:  # 2d check
        # ensuring same size during torchio processing
        patch_size.append(1)
        if dimension is None:
            dimension = 2
    elif len(patch_size) == 3:  # 3d check (FIX: comment previously said "2d")
        if dimension is None:
            dimension = 3
    return [patch_size, dimension]


def validate_norm_type(norm_type, architecture):
    """Reject a missing/'none' normalization type for non-VGG architectures."""
    if norm_type is None or norm_type.lower() == "none":
        # `in` covers both string architectures (substring) and dict
        # architectures (key lookup)
        if not ("vgg" in architecture):
            raise ValueError(
                "Normalization type cannot be 'None' for non-VGG architectures"
            )
    return norm_type


def validate_parallel_compute_command(value):
    """Strip quote characters from the parallel compute command."""
    # TODO: Check it again, should change from ' to ` (original note)
    return value.replace("'", "").replace('"', "")


def validate_schedular(value, learning_rate):
    """Coerce a scheduler name into a Scheduler model, deriving step_size."""
    if isinstance(value, str):
        value = Scheduler(type=value)
        if value.step_size is None:
            # default step size is tied to the learning rate
            value.step_size = learning_rate / 5.0
    return value


def validate_optimizer(value):
    """Coerce an optimizer name into an Optimizer model."""
    if isinstance(value, str):
        value = Optimizer(type=value)
    return value
def validate_data_preprocessing(value) -> dict:
    """Normalize the ``data_preprocessing`` section of the config.

    Handles resize/resample interplay, threshold/clip defaults, and
    histogram-matching/equalization shortcuts.
    """
    if not (value is None):
        # perform this only when pre-processing is defined
        if len(value) > 0:
            thresholdOrClip = False
            # this can be extended, as required
            thresholdOrClipDict = ["threshold", "clip", "clamp"]

            resize_requested = False
            temp_dict = deepcopy(value)
            for key in value:
                if key in ["resize", "resize_image", "resize_images", "resize_patch"]:
                    resize_requested = True

                if key in ["resample_min", "resample_minimum"]:
                    if "resolution" in value[key]:
                        resize_requested = True
                        resolution_temp = np.array(value[key]["resolution"])
                        # a scalar resolution is duplicated into a 2-vector
                        if resolution_temp.size == 1:
                            temp_dict[key]["resolution"] = np.array(
                                [resolution_temp, resolution_temp]
                            ).tolist()
                    else:
                        # resample_min without a resolution is meaningless
                        temp_dict.pop(key)

            value = temp_dict

            # 'resample' wins over any 'resize' variant
            if resize_requested and "resample" in value:
                for key in ["resize", "resize_image", "resize_images", "resize_patch"]:
                    if key in value:
                        value.pop(key)

                print(
                    "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'",
                    file=sys.stderr,
                )

            # iterate through all keys; FIX: iterate over a snapshot because
            # the histogram_* branches below may insert 'histogram_matching'
            # into `value`, and mutating a dict while iterating it raises
            # RuntimeError
            for key in list(value):
                if key in thresholdOrClipDict:
                    # we only allow one of threshold or clip to occur, not both
                    # FIX: raise explicitly instead of `assert` so the check
                    # survives `python -O`; AssertionError kept for
                    # backward compatibility
                    if thresholdOrClip:
                        raise AssertionError("Use only `threshold` or `clip`, not both")
                    thresholdOrClip = True
                    # initialize if nothing is present
                    if not (isinstance(value[key], dict)):
                        value[key] = {}

                    # if min/max are absent, initialize with the widest
                    # possible bounds so a missing field doesn't affect
                    # processing
                    if not "min" in value[key]:
                        value[key]["min"] = sys.float_info.min
                    if not "max" in value[key]:
                        value[key]["max"] = sys.float_info.max

                if key == "histogram_matching":
                    if value[key] is not False:
                        if not (isinstance(value[key], dict)):
                            value[key] = {}

                if key == "histogram_equalization":
                    if value[key] is not False:
                        # if histogram equalization is enabled, call histogram_matching
                        value["histogram_matching"] = {}

                if key == "adaptive_histogram_equalization":
                    if value[key] is not False:
                        # adaptive variant routes through histogram_matching
                        value["histogram_matching"] = {"target": "adaptive"}
    return value


def validate_data_postprocessing_after_reverse_one_hot_encoding(
    value, data_postprocessing
) -> list:
    """Move post-reverse-one-hot postprocessing keys out of data_postprocessing.

    Returns:
        list: ``[value, data_postprocessing]`` after the move.
    """
    temp_dict = deepcopy(value)
    for key in temp_dict:
        if key in postprocessing_after_reverse_one_hot_encoding:
            # NOTE(review): this raises KeyError if a key is present in
            # `value` but absent from `data_postprocessing` — confirm
            # whether iteration should instead run over data_postprocessing
            value[key] = data_postprocessing[key]
            data_postprocessing.pop(key)
    return [value, data_postprocessing]


def validate_patch_sampler(value):
    """Coerce a patch-sampler name into a PatchSampler model."""
    if isinstance(value, str):
        value = PatchSampler(type=value.lower())
    return value
def validate_data_augmentation(value, patch_size) -> dict:
    """Fill per-transform defaults for the ``data_augmentation`` section.

    Args:
        value: augmentation dict from the config (may be None or empty).
        patch_size: normalized per-axis patch size; used to derive the
            default patch size for swap/elastic transforms.

    Returns:
        dict: the augmentation dict with per-transform defaults and a
        per-transform "probability" filled in.
    """
    # FIX: guard against None BEFORE the first subscript — the original
    # indexed `value` first, so its own None-check could never be reached
    if value is None:
        value = {}
    value["default_probability"] = value.get("default_probability", 0.5)
    if len(value) > 0:  # only when augmentations are defined
        # special case for random swapping and elastic transformations -
        # which take a patch size for computation
        for key in ["swap", "elastic"]:
            if key in value:
                value[key] = initialize_key(
                    value[key],
                    "patch_size",
                    np.round(np.array(patch_size) / 10).astype("int").tolist(),
                )

        # special case for swap default initialization
        if "swap" in value:
            value["swap"] = initialize_key(value["swap"], "num_iterations", 100)

        # special case for affine default initialization
        if "affine" in value:
            value["affine"] = initialize_key(value["affine"], "scales", 0.1)
            value["affine"] = initialize_key(value["affine"], "degrees", 15)
            value["affine"] = initialize_key(value["affine"], "translation", 2)

        if "motion" in value:
            value["motion"] = initialize_key(value["motion"], "num_transforms", 2)
            value["motion"] = initialize_key(value["motion"], "degrees", 15)
            value["motion"] = initialize_key(value["motion"], "translation", 2)
            value["motion"] = initialize_key(
                value["motion"], "interpolation", "linear"
            )

        # special case for random blur/noise - which takes a std-dev range
        for std_aug in ["blur", "noise_var"]:
            if std_aug in value:
                value[std_aug] = initialize_key(value[std_aug], "std", None)
        for std_aug in ["noise"]:
            if std_aug in value:
                value[std_aug] = initialize_key(value[std_aug], "std", [0, 1])

        # special case for random noise - which takes a mean range
        for mean_aug in ["noise", "noise_var"]:
            if mean_aug in value:
                value[mean_aug] = initialize_key(value[mean_aug], "mean", 0)

        # special case for augmentations that need an axis defined
        for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]:
            if axis_aug in value:
                value[axis_aug] = initialize_key(value[axis_aug], "axis", [0, 1, 2])

        # special case for colorjitter
        if "colorjitter" in value:
            value = initialize_key(value, "colorjitter", {})
            for key in ["brightness", "contrast", "saturation"]:
                value["colorjitter"] = initialize_key(
                    value["colorjitter"], key, [0, 1]
                )
            value["colorjitter"] = initialize_key(
                value["colorjitter"], "hue", [-0.5, 0.5]
            )

        # Added HED augmentation in gandlf
        hed_augmentation_types = [
            "hed_transform",
            # "hed_transform_light",
            # "hed_transform_heavy",
        ]
        for augmentation_type in hed_augmentation_types:
            if augmentation_type in value:
                value = initialize_key(value, "hed_transform", {})
                ranges = [
                    "haematoxylin_bias_range",
                    "eosin_bias_range",
                    "dab_bias_range",
                    "haematoxylin_sigma_range",
                    "eosin_sigma_range",
                    "dab_sigma_range",
                ]

                default_range = (
                    [-0.1, 0.1]
                    if augmentation_type == "hed_transform"
                    else (
                        [-0.03, 0.03]
                        if augmentation_type == "hed_transform_light"
                        else [-0.95, 0.95]
                    )
                )

                for key in ranges:
                    value["hed_transform"] = initialize_key(
                        value["hed_transform"], key, default_range
                    )

                value["hed_transform"] = initialize_key(
                    value["hed_transform"], "cutoff_range", [0, 1]
                )

        # special case for anisotropic
        if "anisotropic" in value:
            if not ("downsampling" in value["anisotropic"]):
                default_downsampling = 1.5
            else:
                default_downsampling = value["anisotropic"]["downsampling"]

            initialize_downsampling = False
            if isinstance(default_downsampling, list):
                if len(default_downsampling) != 2:
                    initialize_downsampling = True
                    print(
                        "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.",
                        file=sys.stderr,
                    )
                    default_downsampling = default_downsampling[0]  # only
            else:
                initialize_downsampling = True

            if initialize_downsampling:
                if default_downsampling < 1:
                    print(
                        "WARNING: 'anisotropic' augmentation needs the 'downsampling' parameter to be greater than 1, defaulting to 1.5.",
                        file=sys.stderr,
                    )
                    # default
                    value["anisotropic"]["downsampling"] = 1.5
                # NOTE(review): when downsampling >= 1 the computed
                # default_downsampling is never written back to
                # value["anisotropic"]["downsampling"] — confirm this is
                # intended (a malformed list would survive unchanged)

        # every transform inherits the default probability unless it sets
        # its own
        for key in value:
            if key != "default_probability":
                value[key] = initialize_key(
                    value[key], "probability", value["default_probability"]
                )
    return value


def validate_differential_privacy(value, batch_size):
    """Fill defaults for the ``differential_privacy`` (Opacus) section.

    Args:
        value: DP config (None disables DP; non-dicts are replaced).
        batch_size: training batch size, upper bound for physical_batch_size.
    """
    if value is None:
        return value
    if not isinstance(value, dict):
        print(
            "WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
        )
        value = {}
    # these are some defaults
    value = initialize_key(value, "noise_multiplier", 10.0)
    value = initialize_key(value, "max_grad_norm", 1.0)
    value = initialize_key(value, "accountant", "rdp")
    value = initialize_key(value, "secure_mode", False)
    value = initialize_key(value, "allow_opacus_model_fix", True)
    value = initialize_key(value, "delta", 1e-5)
    value = initialize_key(value, "physical_batch_size", batch_size)

    if value["physical_batch_size"] > batch_size:
        # FIX: the two f-string fragments previously concatenated as
        # "greaterthan" — a space was missing
        print(
            f"WARNING: The physical batch size {value['physical_batch_size']} is greater "
            f"than the batch size {batch_size}, setting the physical batch size to the batch size."
        )
        value["physical_batch_size"] = batch_size

    # these keys need to be parsed as floats, not strings
    for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
        if key in value:
            value[key] = float(value[key])
    return value
from typing import TYPE_CHECKING, Optional, Type, Union

if TYPE_CHECKING:
    # pydantic is only needed for the type annotation below
    from pydantic import BaseModel


def generate_and_save_markdown(model: "Type[BaseModel]", file_path: str) -> None:
    """Render a pydantic model's JSON schema as a markdown parameter table.

    Args:
        model: the pydantic model class to document.
        file_path: destination path for the generated markdown file.
    """
    # pydantic v2: model_json_schema() replaces the deprecated .schema()
    schema = model.model_json_schema()
    markdown = []

    # Add title
    markdown.append(f"# {schema['title']}\n")

    # Add description if available
    if "description" in schema:
        markdown.append(f"{schema['description']}\n")

    # Add fields table
    markdown.append("## Parameters\n")
    markdown.append("| Field | Type | Description | Default |")
    markdown.append(
        "|----------------|----------------|-----------------------|------------------|"
    )

    for field_name, field_info in schema["properties"].items():
        # Extract field details
        field_type = field_info.get("type", "N/A")
        description = field_info.get("description", "N/A")
        default = field_info.get("default", "N/A")

        # Add row to the table
        markdown.append(
            f"| `{field_name}` | `{field_type}` | {description} | `{default}` |"
        )

    # Write to file
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("\n".join(markdown))


def initialize_key(
    parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None
) -> dict:
    """
    This function initializes a key in the parameters dictionary to a value if it is absent.

    Args:
        parameters (dict): The parameter dictionary.
        key (str): The key to initialize.
        value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None.

    Returns:
        dict: The parameter dictionary.
    """
    if parameters is None:
        parameters = {}
    if key in parameters:
        # an explicit None is a deliberate override — leave it untouched
        if parameters[key] is not None:
            # an empty dict means "present but not configured" — replace it
            if isinstance(parameters[key], dict) and len(parameters[key]) == 0:
                parameters[key] = value
    else:
        parameters[key] = value  # if key is absent

    return parameters
for performance improvement - "batch_size": 1, # default batch size of training - "learning_rate": 0.001, # default learning rate - "clip_grad": None, # clip_gradient value - "track_memory_usage": False, # default memory tracking - "memory_save_mode": False, # default memory saving, if enabled, resize/resample will save files to disk - "print_rgb_label_warning": True, # print rgb label warning - "data_postprocessing": {}, # default data postprocessing - "grid_aggregator_overlap": "crop", # default grid aggregator overlap strategy - "determinism": False, # using deterministic version of computation - "previous_parameters": None, # previous parameters to be used for resuming training and perform sanity checking -} - -## dictionary to define string defaults for appropriate options -parameter_defaults_string = { - "optimizer": "adam", # the optimizer - "scheduler": "triangle_modified", # the default scheduler - "clip_mode": None, # default clip mode -} - - -def initialize_parameter( - params: dict, - parameter_to_initialize: str, - value: Optional[Union[str, list, int, dict]] = None, - evaluate: Optional[bool] = True, -) -> dict: - """ - This function will initialize the parameter in the parameters dict to the value if it is absent. - - Args: - params (dict): The parameter dictionary. - parameter_to_initialize (str): The parameter to initialize. - value (Optional[Union[str, list, int, dict]], optional): The value to initialize. Defaults to None. - evaluate (Optional[bool], optional): Whether to evaluate the value. Defaults to True. - - Returns: - dict: The parameter dictionary. 
- """ - if parameter_to_initialize in params: - if evaluate: - if isinstance(params[parameter_to_initialize], str): - if params[parameter_to_initialize].lower() == "none": - params[parameter_to_initialize] = ast.literal_eval( - params[parameter_to_initialize] - ) - else: - print( - "WARNING: Initializing '" + parameter_to_initialize + "' as " + str(value) - ) - params[parameter_to_initialize] = value - - return params - - -def initialize_key( - parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None -) -> dict: - """ - This function initializes a key in the parameters dictionary to a value if it is absent. - - Args: - parameters (dict): The parameter dictionary. - key (str): The key to initialize. - value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None. - - Returns: - dict: The parameter dictionary. - """ - if parameters is None: - parameters = {} - if key in parameters: - if parameters[key] is not None: - if isinstance(parameters[key], dict): - # if key is present but not defined - if len(parameters[key]) == 0: - parameters[key] = value - else: - parameters[key] = value # if key is absent - - return parameters +from GANDLF.Configuration.Parameters.parameters import Parameters +from GANDLF.Configuration.Parameters.exclude_parameters import exclude_parameters def _parseConfig( @@ -127,615 +26,12 @@ def _parseConfig( if not isinstance(config_file_path, dict): params = yaml.safe_load(open(config_file_path, "r")) - if version_check_flag: # this is only to be used for testing - assert ( - "version" in params - ), "The 'version' key needs to be defined in config with 'minimum' and 'maximum' fields to determine the compatibility of configuration with code base" - version_check(params["version"], version_to_check=version("GANDLF")) - - if "patch_size" in params: - # duplicate patch size if it is an int or float - if isinstance(params["patch_size"], int) or isinstance( - params["patch_size"], float - 
): - params["patch_size"] = [params["patch_size"]] - # in case someone decides to pass a single value list - if len(params["patch_size"]) == 1: - actual_patch_size = [] - for _ in range(params["model"]["dimension"]): - actual_patch_size.append(params["patch_size"][0]) - params["patch_size"] = actual_patch_size - - # parse patch size as needed for computations - if len(params["patch_size"]) == 2: # 2d check - # ensuring same size during torchio processing - params["patch_size"].append(1) - if "dimension" not in params["model"]: - params["model"]["dimension"] = 2 - elif len(params["patch_size"]) == 3: # 2d check - if "dimension" not in params["model"]: - params["model"]["dimension"] = 3 - assert "patch_size" in params, "Patch size needs to be defined in the config file" - - if "resize" in params: - print( - "WARNING: 'resize' should be defined under 'data_processing', this will be skipped", - file=sys.stderr, - ) - - assert "modality" in params, "'modality' needs to be defined in the config file" - params["modality"] = params["modality"].lower() - assert params["modality"] in [ - "rad", - "histo", - "path", - ], "Modality should be either 'rad' or 'path'" - - assert ( - "loss_function" in params - ), "'loss_function' needs to be defined in the config file" - if "loss_function" in params: - # check if user has passed a dict - if isinstance(params["loss_function"], dict): # if this is a dict - if len(params["loss_function"]) > 0: # only proceed if something is defined - for key in params["loss_function"]: # iterate through all keys - if key == "mse": - if (params["loss_function"][key] is None) or not ( - "reduction" in params["loss_function"][key] - ): - params["loss_function"][key] = {} - params["loss_function"][key]["reduction"] = "mean" - else: - # use simple string for other functions - can be extended with parameters, if needed - params["loss_function"] = key - else: - # check if user has passed a single string - if params["loss_function"] == "mse": - 
params["loss_function"] = {} - params["loss_function"]["mse"] = {} - params["loss_function"]["mse"]["reduction"] = "mean" - elif params["loss_function"] == "focal": - params["loss_function"] = {} - params["loss_function"]["focal"] = {} - params["loss_function"]["focal"]["gamma"] = 2.0 - params["loss_function"]["focal"]["size_average"] = True - - assert "metrics" in params, "'metrics' needs to be defined in the config file" - if "metrics" in params: - if not isinstance(params["metrics"], dict): - temp_dict = {} - else: - temp_dict = params["metrics"] - - # initialize metrics dict - for metric in params["metrics"]: - # assigning a new variable because some metrics can be dicts, and we want to get the first key - comparison_string = metric - if isinstance(metric, dict): - comparison_string = list(metric.keys())[0] - # these metrics always need to be dicts - if comparison_string in [ - "accuracy", - "f1", - "precision", - "recall", - "specificity", - "iou", - ]: - if not isinstance(metric, dict): - temp_dict[metric] = {} - else: - temp_dict[comparison_string] = metric - elif not isinstance(metric, dict): - temp_dict[metric] = None - - # special case for accuracy, precision, recall, and specificity; which could be dicts - ## need to find a better way to do this - if any( - _ in comparison_string - for _ in ["precision", "recall", "specificity", "accuracy", "f1"] - ): - if comparison_string != "classification_accuracy": - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "average", "weighted" - ) - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "multi_class", True - ) - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "mdmc_average", "samplewise" - ) - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "threshold", 0.5 - ) - if comparison_string == "accuracy": - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], 
"subset_accuracy", False - ) - elif "iou" in comparison_string: - temp_dict["iou"] = initialize_key( - temp_dict["iou"], "reduction", "elementwise_mean" - ) - temp_dict["iou"] = initialize_key(temp_dict["iou"], "threshold", 0.5) - elif comparison_string in surface_distance_ids: - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "connectivity", 1 - ) - temp_dict[comparison_string] = initialize_key( - temp_dict[comparison_string], "threshold", None - ) - - params["metrics"] = temp_dict - - # this is NOT a required parameter - a user should be able to train with NO augmentations - params = initialize_key(params, "data_augmentation", {}) - # for all others, ensure probability is present - params["data_augmentation"]["default_probability"] = params[ - "data_augmentation" - ].get("default_probability", 0.5) - - if not (params["data_augmentation"] is None): - if len(params["data_augmentation"]) > 0: # only when augmentations are defined - # special case for random swapping and elastic transformations - which takes a patch size for computation - for key in ["swap", "elastic"]: - if key in params["data_augmentation"]: - params["data_augmentation"][key] = initialize_key( - params["data_augmentation"][key], - "patch_size", - np.round(np.array(params["patch_size"]) / 10) - .astype("int") - .tolist(), - ) - - # special case for swap default initialization - if "swap" in params["data_augmentation"]: - params["data_augmentation"]["swap"] = initialize_key( - params["data_augmentation"]["swap"], "num_iterations", 100 - ) - - # special case for affine default initialization - if "affine" in params["data_augmentation"]: - params["data_augmentation"]["affine"] = initialize_key( - params["data_augmentation"]["affine"], "scales", 0.1 - ) - params["data_augmentation"]["affine"] = initialize_key( - params["data_augmentation"]["affine"], "degrees", 15 - ) - params["data_augmentation"]["affine"] = initialize_key( - params["data_augmentation"]["affine"], 
"translation", 2 - ) - - if "motion" in params["data_augmentation"]: - params["data_augmentation"]["motion"] = initialize_key( - params["data_augmentation"]["motion"], "num_transforms", 2 - ) - params["data_augmentation"]["motion"] = initialize_key( - params["data_augmentation"]["motion"], "degrees", 15 - ) - params["data_augmentation"]["motion"] = initialize_key( - params["data_augmentation"]["motion"], "translation", 2 - ) - params["data_augmentation"]["motion"] = initialize_key( - params["data_augmentation"]["motion"], "interpolation", "linear" - ) - - # special case for random blur/noise - which takes a std-dev range - for std_aug in ["blur", "noise_var"]: - if std_aug in params["data_augmentation"]: - params["data_augmentation"][std_aug] = initialize_key( - params["data_augmentation"][std_aug], "std", None - ) - for std_aug in ["noise"]: - if std_aug in params["data_augmentation"]: - params["data_augmentation"][std_aug] = initialize_key( - params["data_augmentation"][std_aug], "std", [0, 1] - ) - - # special case for random noise - which takes a mean range - for mean_aug in ["noise", "noise_var"]: - if mean_aug in params["data_augmentation"]: - params["data_augmentation"][mean_aug] = initialize_key( - params["data_augmentation"][mean_aug], "mean", 0 - ) - - # special case for augmentations that need axis defined - for axis_aug in ["flip", "anisotropic", "rotate_90", "rotate_180"]: - if axis_aug in params["data_augmentation"]: - params["data_augmentation"][axis_aug] = initialize_key( - params["data_augmentation"][axis_aug], "axis", [0, 1, 2] - ) - - # special case for colorjitter - if "colorjitter" in params["data_augmentation"]: - params["data_augmentation"] = initialize_key( - params["data_augmentation"], "colorjitter", {} - ) - for key in ["brightness", "contrast", "saturation"]: - params["data_augmentation"]["colorjitter"] = initialize_key( - params["data_augmentation"]["colorjitter"], key, [0, 1] - ) - params["data_augmentation"]["colorjitter"] = 
initialize_key( - params["data_augmentation"]["colorjitter"], "hue", [-0.5, 0.5] - ) - - # Added HED augmentation in gandlf - hed_augmentation_types = [ - "hed_transform", - # "hed_transform_light", - # "hed_transform_heavy", - ] - for augmentation_type in hed_augmentation_types: - if augmentation_type in params["data_augmentation"]: - params["data_augmentation"] = initialize_key( - params["data_augmentation"], "hed_transform", {} - ) - ranges = [ - "haematoxylin_bias_range", - "eosin_bias_range", - "dab_bias_range", - "haematoxylin_sigma_range", - "eosin_sigma_range", - "dab_sigma_range", - ] - - default_range = ( - [-0.1, 0.1] - if augmentation_type == "hed_transform" - else ( - [-0.03, 0.03] - if augmentation_type == "hed_transform_light" - else [-0.95, 0.95] - ) - ) - - for key in ranges: - params["data_augmentation"]["hed_transform"] = initialize_key( - params["data_augmentation"]["hed_transform"], - key, - default_range, - ) - - params["data_augmentation"]["hed_transform"] = initialize_key( - params["data_augmentation"]["hed_transform"], - "cutoff_range", - [0, 1], - ) - - # special case for anisotropic - if "anisotropic" in params["data_augmentation"]: - if not ("downsampling" in params["data_augmentation"]["anisotropic"]): - default_downsampling = 1.5 - else: - default_downsampling = params["data_augmentation"]["anisotropic"][ - "downsampling" - ] - - initialize_downsampling = False - if isinstance(default_downsampling, list): - if len(default_downsampling) != 2: - initialize_downsampling = True - print( - "WARNING: 'anisotropic' augmentation needs to be either a single number of a list of 2 numbers: https://torchio.readthedocs.io/transforms/augmentation.html?highlight=randomswap#torchio.transforms.RandomAnisotropy.", - file=sys.stderr, - ) - default_downsampling = default_downsampling[0] # only - else: - initialize_downsampling = True - - if initialize_downsampling: - if default_downsampling < 1: - print( - "WARNING: 'anisotropic' augmentation needs the 
'downsampling' parameter to be greater than 1, defaulting to 1.5.", - file=sys.stderr, - ) - # default - params["data_augmentation"]["anisotropic"]["downsampling"] = 1.5 - - for key in params["data_augmentation"]: - if key != "default_probability": - params["data_augmentation"][key] = initialize_key( - params["data_augmentation"][key], - "probability", - params["data_augmentation"]["default_probability"], - ) - - # this is NOT a required parameter - a user should be able to train with NO built-in pre-processing - params = initialize_key(params, "data_preprocessing", {}) - if not (params["data_preprocessing"] is None): - # perform this only when pre-processing is defined - if len(params["data_preprocessing"]) > 0: - thresholdOrClip = False - # this can be extended, as required - thresholdOrClipDict = ["threshold", "clip", "clamp"] - - resize_requested = False - temp_dict = deepcopy(params["data_preprocessing"]) - for key in params["data_preprocessing"]: - if key in ["resize", "resize_image", "resize_images", "resize_patch"]: - resize_requested = True - - if key in ["resample_min", "resample_minimum"]: - if "resolution" in params["data_preprocessing"][key]: - resize_requested = True - resolution_temp = np.array( - params["data_preprocessing"][key]["resolution"] - ) - if resolution_temp.size == 1: - temp_dict[key]["resolution"] = np.array( - [resolution_temp, resolution_temp] - ).tolist() - else: - temp_dict.pop(key) - - params["data_preprocessing"] = temp_dict - - if resize_requested and "resample" in params["data_preprocessing"]: - for key in ["resize", "resize_image", "resize_images", "resize_patch"]: - if key in params["data_preprocessing"]: - params["data_preprocessing"].pop(key) - - print( - "WARNING: Different 'resize' operations are ignored as 'resample' is defined under 'data_processing'", - file=sys.stderr, - ) - - # iterate through all keys - for key in params["data_preprocessing"]: # iterate through all keys - if key in thresholdOrClipDict: - # we only 
allow one of threshold or clip to occur and not both - assert not ( - thresholdOrClip - ), "Use only `threshold` or `clip`, not both" - thresholdOrClip = True - # initialize if nothing is present - if not (isinstance(params["data_preprocessing"][key], dict)): - params["data_preprocessing"][key] = {} - - # if one of the required parameters is not present, initialize with lowest/highest possible values - # this ensures the absence of a field doesn't affect processing - # for threshold or clip, ensure min and max are defined - if not "min" in params["data_preprocessing"][key]: - params["data_preprocessing"][key]["min"] = sys.float_info.min - if not "max" in params["data_preprocessing"][key]: - params["data_preprocessing"][key]["max"] = sys.float_info.max - - if key == "histogram_matching": - if params["data_preprocessing"][key] is not False: - if not (isinstance(params["data_preprocessing"][key], dict)): - params["data_preprocessing"][key] = {} - - if key == "histogram_equalization": - if params["data_preprocessing"][key] is not False: - # if histogram equalization is enabled, call histogram_matching - params["data_preprocessing"]["histogram_matching"] = {} - - if key == "adaptive_histogram_equalization": - if params["data_preprocessing"][key] is not False: - # if histogram equalization is enabled, call histogram_matching - params["data_preprocessing"]["histogram_matching"] = { - "target": "adaptive" - } - - # this is NOT a required parameter - a user should be able to train with NO built-in post-processing - params = initialize_key(params, "data_postprocessing", {}) - params = initialize_key( - params, "data_postprocessing_after_reverse_one_hot_encoding", {} - ) - temp_dict = deepcopy(params["data_postprocessing"]) - for key in temp_dict: - if key in postprocessing_after_reverse_one_hot_encoding: - params["data_postprocessing_after_reverse_one_hot_encoding"][key] = params[ - "data_postprocessing" - ][key] - params["data_postprocessing"].pop(key) - - if "model" in 
params: - assert isinstance( - params["model"], dict - ), "The 'model' parameter needs to be populated as a dictionary" - assert ( - len(params["model"]) > 0 - ), "The 'model' parameter needs to be populated as a dictionary and should have all properties present" - assert ( - "architecture" in params["model"] - ), "The 'model' parameter needs 'architecture' to be defined" - assert ( - "final_layer" in params["model"] - ), "The 'model' parameter needs 'final_layer' to be defined" - assert ( - "dimension" in params["model"] - ), "The 'model' parameter needs 'dimension' to be defined" - - if "amp" in params["model"]: - pass - else: - print("NOT using Mixed Precision Training") - params["model"]["amp"] = False - - if "norm_type" in params["model"]: - if ( - params["model"]["norm_type"] == None - or params["model"]["norm_type"].lower() == "none" - ): - if not ("vgg" in params["model"]["architecture"]): - raise ValueError( - "Normalization type cannot be 'None' for non-VGG architectures" - ) - else: - print("WARNING: Initializing 'norm_type' as 'batch'", flush=True) - params["model"]["norm_type"] = "batch" - - if not ("base_filters" in params["model"]): - base_filters = 32 - params["model"]["base_filters"] = base_filters - print("Using default 'base_filters' in 'model': ", base_filters) - if not ("class_list" in params["model"]): - params["model"]["class_list"] = [] # ensure that this is initialized - if not ("ignore_label_validation" in params["model"]): - params["model"]["ignore_label_validation"] = None - if "batch_norm" in params["model"]: - print( - "WARNING: 'batch_norm' is no longer supported, please use 'norm_type' in 'model' instead", - flush=True, - ) - params["model"]["print_summary"] = params["model"].get("print_summary", True) - - channel_keys_to_check = ["n_channels", "channels", "model_channels"] - for key in channel_keys_to_check: - if key in params["model"]: - params["model"]["num_channels"] = params["model"][key] - break - - # initialize model type for 
processing: if not defined, default to torch - if not ("type" in params["model"]): - params["model"]["type"] = "torch" - - # initialize openvino model data type for processing: if not defined, default to FP32 - if not ("data_type" in params["model"]): - params["model"]["data_type"] = "FP32" - - # set default save strategy for model - if not ("save_at_every_epoch" in params["model"]): - params["model"]["save_at_every_epoch"] = False - - if params["model"]["save_at_every_epoch"]: - print( - "WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk." - ) - - if isinstance(params["model"]["class_list"], str): - if ("||" in params["model"]["class_list"]) or ( - "&&" in params["model"]["class_list"] - ): - # special case for multi-class computation - this needs to be handled during one-hot encoding mask construction - print( - "WARNING: This is a special case for multi-class computation, where different labels are processed together, `reverse_one_hot` will need mapping information to work correctly" - ) - temp_classList = params["model"]["class_list"] - # we don't need the brackets - temp_classList = temp_classList.replace("[", "") - temp_classList = temp_classList.replace("]", "") - params["model"]["class_list"] = temp_classList.split(",") - else: - try: - params["model"]["class_list"] = eval(params["model"]["class_list"]) - except Exception as e: - ## todo: ensure logging captures assertion errors - assert ( - False - ), f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}" - # logging.error( - # f"Could not evaluate the `class_list` in `model`, Exception: {str(e)}, {traceback.format_exc()}" - # ) - - assert ( - "nested_training" in params - ), "The parameter 'nested_training' needs to be defined" - # initialize defaults for nested training - params["nested_training"]["stratified"] = params["nested_training"].get( - "stratified", False - ) - params["nested_training"]["stratified"] = 
params["nested_training"].get( - "proportional", params["nested_training"]["stratified"] - ) - params["nested_training"]["testing"] = params["nested_training"].get("testing", -5) - params["nested_training"]["validation"] = params["nested_training"].get( - "validation", -5 - ) - - parallel_compute_command = "" - if "parallel_compute_command" in params: - parallel_compute_command = params["parallel_compute_command"] - parallel_compute_command = parallel_compute_command.replace("'", "") - parallel_compute_command = parallel_compute_command.replace('"', "") - params["parallel_compute_command"] = parallel_compute_command - - if "opt" in params: - print("DeprecationWarning: 'opt' has been superseded by 'optimizer'") - params["optimizer"] = params["opt"] - - # initialize defaults for patch sampler - temp_patch_sampler_dict = { - "type": "uniform", - "enable_padding": False, - "padding_mode": "symmetric", - "biased_sampling": False, - } - # check if patch_sampler is defined in the config - if "patch_sampler" in params: - # if "patch_sampler" is a string, then it is the type of sampler - if isinstance(params["patch_sampler"], str): - print( - "WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead" - ) - temp_patch_sampler_dict["type"] = params["patch_sampler"].lower() - elif isinstance(params["patch_sampler"], dict): - # dict requires special handling - for key in params["patch_sampler"]: - temp_patch_sampler_dict[key] = params["patch_sampler"][key] - - # now assign the dict back to the params - params["patch_sampler"] = temp_patch_sampler_dict - del temp_patch_sampler_dict - - # define defaults - for current_parameter in parameter_defaults: - params = initialize_parameter( - params, current_parameter, parameter_defaults[current_parameter], True - ) - - for current_parameter in parameter_defaults_string: - params = initialize_parameter( - params, - current_parameter, - 
parameter_defaults_string[current_parameter], - False, - ) - - # ensure that the scheduler and optimizer are dicts - if isinstance(params["scheduler"], str): - temp_dict = {} - temp_dict["type"] = params["scheduler"] - params["scheduler"] = temp_dict - - if not ("step_size" in params["scheduler"]): - params["scheduler"]["step_size"] = params["learning_rate"] / 5.0 - print( - "WARNING: Setting default step_size to:", params["scheduler"]["step_size"] - ) - - # initialize default optimizer - params["optimizer"] = params.get("optimizer", {}) - if isinstance(params["optimizer"], str): - temp_dict = {} - temp_dict["type"] = params["optimizer"] - params["optimizer"] = temp_dict - - # initialize defaults for DP - if params.get("differential_privacy"): - params = parse_opacus_params(params, initialize_key) - - # initialize defaults for inference mechanism - inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0} - initialize_inference_mechanism = False - if not ("inference_mechanism" in params): - initialize_inference_mechanism = True - elif not (isinstance(params["inference_mechanism"], dict)): - initialize_inference_mechanism = True - else: - for key in inference_mechanism: - if not (key in params["inference_mechanism"]): - params["inference_mechanism"][key] = inference_mechanism[key] - - if initialize_inference_mechanism: - params["inference_mechanism"] = inference_mechanism - return params def ConfigManager( config_file_path: Union[str, dict], version_check_flag: bool = True -) -> None: +) -> dict: """ This function parses the configuration file and returns a dictionary of parameters. @@ -747,7 +43,17 @@ def ConfigManager( dict: The parameter dictionary. 
""" try: - return _parseConfig(config_file_path, version_check_flag) + parameters_config = Parameters( + **_parseConfig(config_file_path, version_check_flag) + ) + parameters = parameters_config.model_dump( + exclude={ + field + for field in exclude_parameters + if getattr(parameters_config, field) is None + } + ) + return parameters except Exception as e: ## todo: ensure logging captures assertion errors assert ( diff --git a/setup.py b/setup.py index fdb5d7c6f..d7ccf4876 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,8 @@ import sys, re, os + + from setuptools import setup, find_packages @@ -85,6 +87,7 @@ "openslide-bin", "openslide-python==1.4.1", "lion-pytorch==0.2.2", + "pydantic", ] if __name__ == "__main__": diff --git a/testing/test_full.py b/testing/test_full.py index eccf0b3c8..d25268df1 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import logging +import json from pydicom.data import get_testdata_file import cv2 @@ -988,8 +989,8 @@ def test_train_scheduler_classification_rad_2d(device): parameters = populate_header_in_parameters(parameters, parameters["headers"]) parameters["model"]["onnx_export"] = False parameters["model"]["print_summary"] = False - parameters["scheduler"] = {} - parameters["scheduler"]["type"] = scheduler + parameters["scheduler"] = scheduler + # parameters["scheduler"]["type"] = scheduler parameters["nested_training"]["testing"] = -5 parameters["nested_training"]["validation"] = -5 sanitize_outputDir() @@ -3362,6 +3363,8 @@ def test_differential_privacy_epsilon_classification_rad_2d(device): yaml.dump(parameters, file) parameters = parseConfig(file_config_temp, version_check_flag=True) + print(json.dumps(parameters)) + TrainingManager( dataframe=training_data, outputDir=outputDir,