diff --git a/GANDLF/configuration/scheduler_config.py b/GANDLF/configuration/scheduler_config.py
index ed1e19e41..1990a90b6 100644
--- a/GANDLF/configuration/scheduler_config.py
+++ b/GANDLF/configuration/scheduler_config.py
@@ -1,17 +1,91 @@
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Literal
-
+from typing import Optional, Union
 from GANDLF.schedulers import global_schedulers_dict

 TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]


+class base_triangle_config(BaseModel):
+    min_lr: float = Field(default=1e-3)
+    max_lr: float = Field(default=1.0)
+    step_size: Optional[float] = Field(description="Scheduler step size", default=None)
+
+
+class triangle_modified_config(BaseModel):
+    min_lr: float = Field(default=0.000001)
+    max_lr: float = Field(default=0.001)
+    max_lr_multiplier: float = Field(default=1.0)
+    step_size: Optional[float] = Field(description="Scheduler step size", default=None)
+
+
+class cyclic_lr_base_config(BaseModel):
+    # More details: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html
+    min_lr: Optional[float] = Field(
+        default=None
+    )  # defaults to learning_rate * 0.001, filled in during validation
+    max_lr: Optional[float] = Field(default=None)  # filled in during validation
+    gamma: float = Field(default=0.1)
+    scale_mode: Literal["cycle", "iterations"] = Field(default="cycle")
+    cycle_momentum: bool = Field(default=False)
+    base_momentum: float = Field(default=0.8)
+    max_momentum: float = Field(default=0.9)
+    step_size: Optional[float] = Field(description="Scheduler step size", default=None)
+
+
+class exp_config(BaseModel):
+    gamma: float = Field(default=0.1)
+
+
+class step_config(BaseModel):
+    gamma: float = Field(default=0.1)
+    step_size: Optional[float] = Field(description="Scheduler step size", default=None)
+
+
+class cosineannealing_config(BaseModel):
+    # More details: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html
+    T_0: int = Field(default=5)
+    T_mult: float = Field(default=1)
+    min_lr: float = Field(default=0.001)
+
+
+class reduce_on_plateau_config(BaseModel):
+    # More details: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
+    min_lr: Optional[Union[float, list]] = Field(default=None)
+    gamma: float = Field(default=0.1)
+    mode: Literal["min", "max"] = Field(default="min")
+    factor: float = Field(default=0.1)
+    patience: int = Field(default=10)
+    threshold: float = Field(default=0.0001)
+    cooldown: int = Field(default=0)
+    threshold_mode: Literal["rel", "abs"] = Field(default="rel")
+
+
+class warmupcosineschedule_config(BaseModel):
+    # More details: https://docs.monai.io/en/stable/optimizers.html#monai.optimizers.WarmupCosineSchedule
+    warmup_steps: Optional[int] = Field(default=None)  # defaults to 10% of num_epochs
+
+
 # It allows extra parameters
 class SchedulerConfig(BaseModel):
     model_config = ConfigDict(extra="allow")
-    type: TYPE_OPTIONS = Field(
-        description="triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range uses CyclicLR"
-    )
-    # min_lr: 0.00001, #TODO: this should be defined ??
-    # max_lr: 1, #TODO: this should be defined ??
-    step_size: float = Field(description="step_size", default=None)
+    type: TYPE_OPTIONS = Field(description="scheduler type")
+
+
+# Map each scheduler type to its scheduler-specific config model
+schedulers_dict_config = {
+    "triangle": base_triangle_config,
+    "triangle_modified": triangle_modified_config,
+    "triangular": cyclic_lr_base_config,
+    "exp_range": cyclic_lr_base_config,
+    "exp": exp_config,
+    "exponential": exp_config,
+    "step": step_config,
+    "reduce_on_plateau": reduce_on_plateau_config,
+    "reduce-on-plateau": reduce_on_plateau_config,
+    "plateau": reduce_on_plateau_config,
+    "reduceonplateau": reduce_on_plateau_config,
+    "cosineannealing": cosineannealing_config,
+    "warmupcosineschedule": warmupcosineschedule_config,
+    "wcs": warmupcosineschedule_config,
+}
diff --git a/GANDLF/configuration/user_defined_config.py b/GANDLF/configuration/user_defined_config.py
index e44b6c1ef..02229c598 100644
--- a/GANDLF/configuration/user_defined_config.py
+++ b/GANDLF/configuration/user_defined_config.py
@@ -103,7 +103,9 @@ def validate(self) -> Self:
             self.parallel_compute_command
         )
         # validate scheduler
-        self.scheduler = validate_schedular(self.scheduler, self.learning_rate)
+        self.scheduler = validate_schedular(
+            self.scheduler, self.learning_rate, self.num_epochs
+        )
         # validate optimizer
         self.optimizer = validate_optimizer(self.optimizer)
         # validate patch_sampler
diff --git a/GANDLF/configuration/utils.py b/GANDLF/configuration/utils.py
index 850fb9c33..f8dc09cad 100644
--- a/GANDLF/configuration/utils.py
+++ b/GANDLF/configuration/utils.py
@@ -3,7 +3,7 @@
 from typing import Type


-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ValidationError, create_model
 from pydantic_core import ErrorDetails


@@ -107,3 +107,24 @@ def handle_configuration_errors(e: ValidationError):
     messages = extract_messages(convert_errors(e))
     for message in messages:
         logging.error(message)
+
+
+def combine_models(base_model: Type[BaseModel], extra_model: Type[BaseModel]):
+    """Dynamically combine the fields of a base model and an extra model."""
+    fields = {}
+    # Collect the base model fields, keeping defaults where they exist
+    for field_name, field_info in base_model.model_fields.items():
+        fields[field_name] = (
+            field_info.annotation,
+            field_info.default if not field_info.is_required() else ...,
+        )
+
+    # Add the fields from the extra model, overriding any duplicates
+    for field_name, field_info in extra_model.model_fields.items():
+        fields[field_name] = (
+            field_info.annotation,
+            field_info.default if not field_info.is_required() else ...,
+        )
+
+    # Return the new dynamically combined model
+    return create_model(base_model.__name__, **fields)
diff --git a/GANDLF/configuration/validators.py b/GANDLF/configuration/validators.py
index fff0f1539..ae218f8eb 100644
--- a/GANDLF/configuration/validators.py
+++ b/GANDLF/configuration/validators.py
@@ -8,8 +8,12 @@
 import sys
 from GANDLF.configuration.optimizer_config import OptimizerConfig
 from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
-from GANDLF.configuration.scheduler_config import SchedulerConfig
-from GANDLF.configuration.utils import initialize_key
+from GANDLF.configuration.scheduler_config import (
+    SchedulerConfig,
+    base_triangle_config,
+    schedulers_dict_config,
+)
+from GANDLF.configuration.utils import initialize_key, combine_models
 from GANDLF.metrics import surface_distance_ids


@@ -169,11 +173,27 @@ def validate_parallel_compute_command(value):
     return value


-def validate_schedular(value, learning_rate):
+def validate_schedular(value, learning_rate, num_epochs):
     if isinstance(value, str):
         value = SchedulerConfig(type=value)
-    if value.step_size is None:
+    # Look up the scheduler-specific config class for this type
+    scheduler_specific_class = schedulers_dict_config[value.type]
+    # Combine it with the generic SchedulerConfig class
+    combined_class = combine_models(SchedulerConfig, scheduler_specific_class)
+    combined_scheduler = combined_class(**value.model_dump())
+    value = SchedulerConfig(**combined_scheduler.model_dump())
+
+    if value.type == "triangular":
+        if value.min_lr is None:
+            value.min_lr = learning_rate * 0.001
+        if value.max_lr is None:
+            value.max_lr = learning_rate
+    if value.type in ["warmupcosineschedule", "wcs"]:
+        if value.warmup_steps is None:
+            value.warmup_steps = int(num_epochs * 0.1)
+    if hasattr(value, "step_size") and value.step_size is None:
         value.step_size = learning_rate / 5.0
+
     return value
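Usage sketch for reviewers (not part of the diff). A minimal, hypothetical example of how the combined validation behaves once this change is applied; the learning-rate and epoch values are illustrative, and the scheduler names are assumed to be registered in global_schedulers_dict:

    from GANDLF.configuration.scheduler_config import SchedulerConfig
    from GANDLF.configuration.validators import validate_schedular

    # A bare string is promoted to a SchedulerConfig, merged with its
    # type-specific model via combine_models, and derived defaults filled in.
    cfg = validate_schedular("triangular", learning_rate=0.01, num_epochs=50)
    print(cfg.min_lr)     # 1e-05 -> learning_rate * 0.001
    print(cfg.max_lr)     # 0.01  -> defaults to the learning rate
    print(cfg.step_size)  # 0.002 -> learning_rate / 5.0, unchanged behavior

    # Scheduler-specific keys are validated against the per-type model, so a
    # plateau config exposes its own fields (patience, mode, ...) with defaults.
    cfg = validate_schedular(
        SchedulerConfig(type="plateau", patience=5), learning_rate=0.01, num_epochs=50
    )
    print(cfg.patience)   # 5
    print(cfg.mode)       # "min" (default from reduce_on_plateau_config)

Note that the warmupcosineschedule/wcs path fills warmup_steps from num_epochs only when the user has not set it explicitly.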