Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 81 additions & 7 deletions GANDLF/configuration/scheduler_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,91 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal

from typing import Union
from GANDLF.schedulers import global_schedulers_dict

# Literal built from every scheduler name registered in GANDLF, so pydantic
# rejects unknown ``type`` values at validation time.
TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]


class base_triangle_config(BaseModel):
    """Parameters for the ``triangle`` (triangular LambdaLR) scheduler."""

    min_lr: float = Field(default=(10**-3))
    max_lr: float = Field(default=1)
    # ``None`` means "derive later from the learning rate" (see
    # validate_schedular). Annotated as optional because a plain ``float``
    # annotation rejects an explicitly supplied ``None`` under pydantic v2.
    step_size: Union[float, None] = Field(description="step_size", default=None)


class triangle_modified_config(BaseModel):
    """Parameters for the ``triangle_modified`` (LambdaLR-based) scheduler."""

    min_lr: float = Field(default=0.000001)
    max_lr: float = Field(default=0.001)
    max_lr_multiplier: float = Field(default=1.0)
    # ``None`` means "derive later from the learning rate"; optional so an
    # explicitly supplied ``None`` also validates under pydantic v2.
    step_size: Union[float, None] = Field(description="step_size", default=None)


class cyclic_lr_base_config(BaseModel):
    """Parameters shared by the ``triangular`` and ``exp_range`` schedulers.

    More details:
    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html
    """

    # ``None`` defaults are filled in later (see validate_schedular); the
    # fields are annotated as optional because a bare ``float`` annotation
    # rejects an explicitly supplied ``None`` under pydantic v2.
    min_lr: Union[float, None] = Field(
        default=None
    )  # The default value is calculated according the learning rate * 0.001
    max_lr: Union[float, None] = Field(default=None)  # calculate in the validation stage
    gamma: float = Field(default=0.1)
    scale_mode: Literal["cycle", "iterations"] = Field(default="cycle")
    cycle_momentum: bool = Field(default=False)
    base_momentum: float = Field(default=0.8)
    max_momentum: float = Field(default=0.9)
    step_size: Union[float, None] = Field(description="step_size", default=None)


class exp_config(BaseModel):
    """Parameters for the ``exp``/``exponential`` (ExponentialLR) scheduler."""

    # Multiplicative learning-rate decay factor.
    gamma: float = 0.1


class step_config(BaseModel):
    """Parameters for the ``step`` (StepLR) scheduler."""

    gamma: float = Field(default=0.1)
    # ``None`` means "derive later from the learning rate"; optional so an
    # explicitly supplied ``None`` also validates under pydantic v2.
    step_size: Union[float, None] = Field(description="step_size", default=None)


class cosineannealing_config(BaseModel):
    """Parameters for the ``cosineannealing`` scheduler.

    More details:
    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html
    """

    # Number of iterations/epochs until the first restart.
    T_0: int = 5
    # Restart-period growth factor. NOTE(review): torch's
    # CosineAnnealingWarmRestarts expects an integer >= 1 here -- confirm
    # whether non-integral values are ever intended.
    T_mult: float = 1
    # Minimum learning rate (presumably mapped to ``eta_min``; verify at the
    # scheduler construction site).
    min_lr: float = 0.001


class reduce_on_plateau_config(BaseModel):
    """Parameters for the ``reduce_on_plateau`` scheduler (and its aliases).

    More details:
    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
    """

    # ``None`` default is filled in later; annotated as optional because the
    # bare ``Union[float, list]`` annotation rejects an explicitly supplied
    # ``None`` under pydantic v2.
    min_lr: Union[float, list, None] = Field(default=None)
    gamma: float = Field(default=0.1)
    mode: Literal["min", "max"] = Field(default="min")
    factor: float = Field(default=0.1)
    patience: int = Field(default=10)
    threshold: float = Field(default=0.0001)
    cooldown: int = Field(default=0)
    threshold_mode: Literal["rel", "abs"] = Field(default="rel")


class warmupcosineschedule_config(BaseModel):
    """Parameters for the ``warmupcosineschedule``/``wcs`` scheduler.

    More details:
    https://docs.monai.io/en/stable/optimizers.html#monai.optimizers.WarmupCosineSchedule
    """

    # ``None`` means "derive later from num_epochs" (see validate_schedular);
    # optional so an explicitly supplied ``None`` also validates under
    # pydantic v2.
    warmup_steps: Union[int, None] = Field(default=None)


# It allows extra parameters
class SchedulerConfig(BaseModel):
    """Top-level scheduler configuration.

    Holds the scheduler ``type`` plus any scheduler-specific parameters,
    which are permitted via ``extra="allow"`` and validated later against the
    per-type models in ``schedulers_dict_config``.
    """

    model_config = ConfigDict(extra="allow")
    # NOTE: the original declared ``type`` twice; the later duplicate silently
    # overrode this one. Keep a single declaration.
    type: TYPE_OPTIONS = Field(
        description="triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range uses CyclicLR"
    )
    # ``None`` means "derive later from the learning rate"; optional so an
    # explicitly supplied ``None`` also validates under pydantic v2.
    step_size: Union[float, None] = Field(description="step_size", default=None)


# Define the type and the scheduler base model class. Each per-type config
# model is listed once, together with every user-facing name (alias) that
# selects it; the flat lookup dict is then derived from those groups.
_scheduler_alias_groups = (
    (base_triangle_config, ("triangle",)),
    (triangle_modified_config, ("triangle_modified",)),
    (cyclic_lr_base_config, ("triangular", "exp_range")),
    (exp_config, ("exp", "exponential")),
    (step_config, ("step",)),
    (
        reduce_on_plateau_config,
        ("reduce_on_plateau", "reduce-on-plateau", "plateau", "reduceonplateau"),
    ),
    (cosineannealing_config, ("cosineannealing",)),
    (warmupcosineschedule_config, ("warmupcosineschedule", "wcs")),
)

schedulers_dict_config = {
    alias: config_class
    for config_class, aliases in _scheduler_alias_groups
    for alias in aliases
}
4 changes: 3 additions & 1 deletion GANDLF/configuration/user_defined_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def validate(self) -> Self:
self.parallel_compute_command
)
# validate scheduler
self.scheduler = validate_schedular(self.scheduler, self.learning_rate)
self.scheduler = validate_schedular(
self.scheduler, self.learning_rate, self.num_epochs
)
# validate optimizer
self.optimizer = validate_optimizer(self.optimizer)
# validate patch_sampler
Expand Down
23 changes: 22 additions & 1 deletion GANDLF/configuration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


from typing import Type
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ValidationError, create_model
from pydantic_core import ErrorDetails


Expand Down Expand Up @@ -107,3 +107,24 @@ def handle_configuration_errors(e: ValidationError):
messages = extract_messages(convert_errors(e))
for message in messages:
logging.error(message)


def combine_models(base_model: Type[BaseModel], extra_model: Type[BaseModel]):
    """Combine base model with an extra model dynamically.

    Builds a new pydantic model (named after ``base_model``) containing the
    fields of both models; on a name collision the field from ``extra_model``
    wins because it is processed second.
    """
    fields = {}
    # Base first, extra second, so extra_model overrides duplicates.
    for model in (base_model, extra_model):
        for field_name, field_info in model.model_fields.items():
            # ``...`` (Ellipsis) marks a required field for ``create_model``;
            # otherwise reuse the declared default. (The original compared the
            # stored default against Ellipsis, which was a no-op.)
            default = ... if field_info.is_required() else field_info.default
            fields[field_name] = (field_info.annotation, default)

    # Return the new dynamically combined model
    return create_model(base_model.__name__, **fields)
27 changes: 23 additions & 4 deletions GANDLF/configuration/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
import sys
from GANDLF.configuration.optimizer_config import OptimizerConfig
from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
from GANDLF.configuration.scheduler_config import SchedulerConfig
from GANDLF.configuration.utils import initialize_key
from GANDLF.configuration.scheduler_config import (
SchedulerConfig,
base_triangle_config,
schedulers_dict_config,
)
from GANDLF.configuration.utils import initialize_key, combine_models
from GANDLF.metrics import surface_distance_ids


Expand Down Expand Up @@ -169,11 +173,26 @@ def validate_parallel_compute_command(value):
return value


def validate_schedular(value, learning_rate, num_epochs):
    """Normalize and validate the scheduler configuration.

    Args:
        value: either a scheduler-type string or a SchedulerConfig instance.
        learning_rate: base learning rate, used to derive unset defaults.
        num_epochs: total training epochs, used to derive ``warmup_steps``.

    Returns:
        SchedulerConfig: validated config with type-specific defaults filled in.
    """
    if isinstance(value, str):
        value = SchedulerConfig(type=value)
    # Find the scheduler_config class based on the type, then merge it with
    # SchedulerConfig so every type-specific field (and default) is present.
    combine_scheduler_class = schedulers_dict_config[value.type]
    schedulerConfigCombine = combine_models(SchedulerConfig, combine_scheduler_class)
    combineScheduler = schedulerConfigCombine(**value.model_dump())
    value = SchedulerConfig(**combineScheduler.model_dump())

    if value.type == "triangular":
        # NOTE(review): "exp_range" shares cyclic_lr_base_config (with None
        # min_lr/max_lr defaults) but gets no derived values here -- confirm
        # whether that is intentional.
        if value.min_lr is None:
            value.min_lr = learning_rate * 0.001
        if value.max_lr is None:
            value.max_lr = learning_rate
    if value.type in ["warmupcosineschedule", "wcs"]:
        # Derive a default only when unset; the previous version
        # unconditionally overwrote a user-supplied warmup_steps.
        if value.warmup_steps is None:
            value.warmup_steps = num_epochs * 0.1
    if hasattr(value, "step_size") and value.step_size is None:
        value.step_size = learning_rate / 5.0

    return value


Expand Down
Loading