Merged

88 commits
51a5b52
init the pydantic configuration
benmalef Jan 30, 2025
2458e5e
add default_parameters
benmalef Jan 30, 2025
df290e3
add utils and generate_and_save_markdown function
benmalef Jan 30, 2025
66bf03d
add user_defined_parameters.py
benmalef Jan 30, 2025
9539f06
blacked .
benmalef Jan 30, 2025
cab38d7
blacked .
benmalef Jan 30, 2025
0e299db
define user_defined_parameters.py
benmalef Feb 1, 2025
576577d
create parameters(BaseModel)
benmalef Feb 1, 2025
5b0ae98
made some code changes in generate_documentation.py
benmalef Feb 1, 2025
10d5f78
add pydantic in setup.py
benmalef Feb 1, 2025
207e4c0
comment out config_manager
benmalef Feb 1, 2025
39c5b1b
add a temp test dir for testing purposes
benmalef Feb 1, 2025
adc2ebc
fix spelling errors
benmalef Feb 1, 2025
24de26a
blacked .
benmalef Feb 1, 2025
2ddad48
refactor user_defined_parameters.py
benmalef Feb 3, 2025
2b2fc03
add patch_size parameter and validation
benmalef Feb 3, 2025
8c8a296
add parameter <modality>
benmalef Feb 4, 2025
99f1158
add parameter loss_function
benmalef Feb 4, 2025
5c2731a
refactor the code
benmalef Feb 4, 2025
f208db3
add model parameters and refactor the code
benmalef Feb 4, 2025
7dc0f3f
update model_parameters.py
benmalef Feb 6, 2025
d3b33f8
update model_parameters.py
benmalef Feb 8, 2025
61803aa
add nestedTraining
benmalef Feb 8, 2025
5ce1f28
update the validators.py
benmalef Feb 8, 2025
6df3cf1
create scheduler_parameters.py
benmalef Feb 8, 2025
3bde096
create nested_training_parameters.py
benmalef Feb 8, 2025
ed5f7f5
update the test_configuration.py
benmalef Feb 8, 2025
09c6694
fix scheduler step_size
benmalef Feb 8, 2025
e0d6d0b
change the configuration structure
benmalef Feb 8, 2025
6ef1701
change the location of validators file
benmalef Feb 8, 2025
7e2462b
update the test_configuration.py
benmalef Feb 8, 2025
aedc67e
update the configuration with patch_sampler.py
benmalef Feb 8, 2025
79fc32d
update the configuration
benmalef Feb 9, 2025
a35f399
blacked .
benmalef Feb 9, 2025
8fc7724
see the full test
benmalef Feb 9, 2025
0f11356
delete test_configuration.py
benmalef Feb 9, 2025
a9162f7
clean the configuration
benmalef Feb 9, 2025
2cf78f3
black
benmalef Feb 9, 2025
9713f2e
delete test config_all_options.yaml
benmalef Feb 9, 2025
abaa1a8
fix version bug
benmalef Feb 9, 2025
dfac018
fix the num_channels bug in the tests
benmalef Feb 9, 2025
a0622b6
change the metrics type
benmalef Feb 9, 2025
af5e1e1
update validators and user_defined_parameters.py
benmalef Feb 9, 2025
7db6dc5
blacked .
benmalef Feb 9, 2025
aa91f05
updated model architecture with "vgg16"
benmalef Feb 9, 2025
1c99606
blacked .
benmalef Feb 9, 2025
1ec5477
update the model and the nested_training
benmalef Feb 10, 2025
a30b986
update the model and the user_defined_parameters
benmalef Feb 10, 2025
e6714f1
minor changes
benmalef Feb 10, 2025
f3ccfea
update model_parameters
benmalef Feb 10, 2025
d19bf5a
update validators and user_defined_parameters with data_postprocessin…
benmalef Feb 10, 2025
b054a5f
update user_defined_parameters.py
benmalef Feb 10, 2025
27542d8
update user_defined_parameters.py
benmalef Feb 10, 2025
0afd0e5
update user_defined_parameters.py
benmalef Feb 10, 2025
afe7fd6
update user_defined_parameters.py
benmalef Feb 10, 2025
ef55332
fix spelling error
benmalef Feb 10, 2025
02dc310
update the scheduler
benmalef Feb 10, 2025
55b32c8
update the scheduler
benmalef Feb 10, 2025
cb70034
update user_defined_parameters
benmalef Feb 10, 2025
c7f7f46
change scheduler_classification_rad_2d
benmalef Feb 10, 2025
c02dbe1
update scheduler_parameters.py
benmalef Feb 10, 2025
afca656
update scheduler_parameters.py
benmalef Feb 10, 2025
2fd32ba
added differential_privacy in parameters
benmalef Feb 10, 2025
6003152
added differential_privacy in parameters
benmalef Feb 10, 2025
cf7b351
remove batch_norm
benmalef Feb 10, 2025
007a967
remove batch_norm
benmalef Feb 10, 2025
61acfdc
update configuration
benmalef Feb 10, 2025
60aa3d5
add exclude parameters
benmalef Feb 10, 2025
6443161
add exclude parameters
benmalef Feb 10, 2025
9953aac
update scheduler_parameters.py
benmalef Feb 11, 2025
c4e0579
update literals
benmalef Feb 11, 2025
af437b9
update literals
benmalef Feb 11, 2025
caa4dca
update the differential_privacy
benmalef Feb 11, 2025
a0aeb43
update the differential_privacy
benmalef Feb 11, 2025
3d63edc
update validators
benmalef Feb 11, 2025
612d2e0
change the workflow
benmalef Feb 11, 2025
d7bef63
update the validators.py
benmalef Feb 11, 2025
dcea6ab
change the test_full
benmalef Feb 11, 2025
d08d8ee
change workflow
benmalef Feb 11, 2025
3e902c2
update validators.py
benmalef Feb 11, 2025
fc2ed69
update pydantic configuration
benmalef Feb 11, 2025
3ca341a
update pydantic configuration
benmalef Feb 11, 2025
6606de4
fix norm_type error
benmalef Feb 11, 2025
f624516
update test_full
benmalef Feb 11, 2025
afd63de
update test_full
benmalef Feb 11, 2025
dae6c1e
update config_manager.py
benmalef Feb 11, 2025
17c1d6d
update config_manager.py and test_full.py
benmalef Feb 12, 2025
c7c696f
revert python_tests in GitHub workflows
benmalef Feb 12, 2025
67 changes: 67 additions & 0 deletions GANDLF/Configuration/Parameters/default_parameters.py
@@ -0,0 +1,67 @@
from pydantic import BaseModel, Field
from typing import Dict, Optional, Union


class DefaultParameters(BaseModel):
weighted_loss: bool = Field(
default=False, description="Whether weighted loss is to be used or not."
)
verbose: bool = Field(default=False, description="General application verbosity.")
q_verbose: bool = Field(default=False, description="Queue construction verbosity.")
medcam_enabled: bool = Field(
default=False, description="Enable interpretability via medcam."
)
save_training: bool = Field(
default=False, description="Save outputs during training."
)
save_output: bool = Field(
default=False, description="Save outputs during validation/testing."
)
in_memory: bool = Field(default=False, description="Pin data to CPU memory.")
pin_memory_dataloader: bool = Field(
default=False, description="Pin data to GPU memory."
)
scaling_factor: int = Field(
default=1, description="Scaling factor for regression problems."
)
q_max_length: int = Field(default=100, description="The max length of the queue.")
q_samples_per_volume: int = Field(
default=10, description="Number of samples per volume."
)
q_num_workers: int = Field(
default=4, description="Number of worker threads to use."
)
num_epochs: int = Field(default=100, description="Total number of epochs to train.")
patience: int = Field(
default=100, description="Number of epochs to wait for performance improvement."
)
batch_size: int = Field(default=1, description="Default batch size for training.")
learning_rate: float = Field(default=0.001, description="Default learning rate.")
clip_grad: Optional[float] = Field(
default=None, description="Gradient clipping value."
)
track_memory_usage: bool = Field(
default=False, description="Enable memory usage tracking."
)
memory_save_mode: bool = Field(
default=False,
description="Enable memory-saving mode. If enabled, resize/resample will save files to disk.",
)
print_rgb_label_warning: bool = Field(
default=True, description="Print a warning for RGB labels."
)
data_postprocessing: Union[dict, set] = Field(
default={}, description="Default data postprocessing configuration."
)
grid_aggregator_overlap: str = Field(
default="crop", description="Default grid aggregator overlap strategy."
)
determinism: bool = Field(
default=False, description="Enable deterministic computation."
)
previous_parameters: Optional[Dict] = Field(
default=None,
description="Previous parameters to be used for resuming training and performing sanity checks.",
)
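
A minimal standalone sketch of how these defaults behave under pydantic v2 (hypothetical Defaults model, not GANDLF's class): a user config only needs to supply the keys it changes, and string values are coerced in lax mode.

from typing import Optional

from pydantic import BaseModel, Field


class Defaults(BaseModel):
    batch_size: int = Field(default=1, description="Default batch size for training.")
    learning_rate: float = Field(default=0.001, description="Default learning rate.")
    clip_grad: Optional[float] = Field(default=None, description="Gradient clipping value.")


print(Defaults().batch_size)                # 1 -- default applies when the key is absent
print(Defaults(batch_size=8).batch_size)    # 8 -- user value overrides the default
print(Defaults(batch_size="8").batch_size)  # 8 -- pydantic coerces str -> int in lax mode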
1 change: 1 addition & 0 deletions GANDLF/Configuration/Parameters/exclude_parameters.py
@@ -0,0 +1 @@
exclude_parameters = {"differential_privacy"}
69 changes: 69 additions & 0 deletions GANDLF/Configuration/Parameters/model_parameters.py
@@ -0,0 +1,69 @@
from pydantic import BaseModel, model_validator, Field, AliasChoices, ConfigDict

from typing_extensions import Self, Literal, Optional
from typing import Union
from GANDLF.Configuration.Parameters.validators import (
validate_class_list,
validate_norm_type,
)
from GANDLF.models import global_models_dict

# Define model architecture options
ARCHITECTURE_OPTIONS = Literal[tuple(global_models_dict.keys())]
NORM_TYPE_OPTIONS = Literal["batch", "instance", "none"]


# New model parameters can be defined here; see the pydantic documentation.
class Model(BaseModel):
    model_config = ConfigDict(extra="allow")  # undeclared keys in the model dict are kept
    dimension: Optional[int] = Field(description="Model dimensionality (2 or 3).")
architecture: Union[ARCHITECTURE_OPTIONS, dict] = Field(description="Architecture.")
final_layer: str = Field(description="Final layer.")
norm_type: Optional[NORM_TYPE_OPTIONS] = Field(
description="Normalization type.", default="batch"
) # TODO: check it again
base_filters: Optional[int] = Field(
description="Base filters.", default=None, validate_default=True
) # default is 32
class_list: Union[list, str] = Field(default=[], description="Class list.")
num_channels: Optional[int] = Field(
description="Number of channels.",
validation_alias=AliasChoices(
"num_channels", "n_channels", "channels", "model_channels"
),
default=3,
) # TODO: check it
type: Optional[str] = Field(description="Type of model.", default="torch")
data_type: str = Field(description="Data type.", default="FP32")
save_at_every_epoch: bool = Field(default=False, description="Save at every epoch.")
    amp: bool = Field(default=False, description="Automatic mixed precision (AMP) training.")
ignore_label_validation: Union[int, None] = Field(
default=None, description="Ignore label validation."
) # TODO: To check it
print_summary: bool = Field(default=True, description="Print summary.")

@model_validator(mode="after")
    def validate_model(self) -> Self:  # renamed so it does not shadow BaseModel.model_validate
# TODO: Change the print to logging.warnings
self.class_list = validate_class_list(
self.class_list
) # init and validate the class_list parameter
self.norm_type = validate_norm_type(
self.norm_type, self.architecture
) # init and validate the norm type
if self.amp is False:
print("NOT using Mixed Precision Training")

if self.save_at_every_epoch:
print(
"WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
) # TODO: It is better to use logging.warning

if self.base_filters is None:
self.base_filters = 32
print("Using default 'base_filters' in 'model': ", self.base_filters)

return self
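
This file leans on two pydantic v2 features: extra="allow" (unknown keys survive parsing instead of being rejected) and validation_alias (several YAML spellings resolve to one canonical field). A minimal sketch under those assumptions, with a hypothetical MiniModel:

from pydantic import AliasChoices, BaseModel, ConfigDict, Field


class MiniModel(BaseModel):
    model_config = ConfigDict(extra="allow")  # unknown keys are kept, not rejected
    num_channels: int = Field(
        default=3,
        validation_alias=AliasChoices("num_channels", "n_channels", "channels"),
    )


m = MiniModel(n_channels=1, my_custom_flag=True)
print(m.num_channels)  # 1 -- alias resolved to the canonical field name
print(m.model_extra)   # {'my_custom_flag': True} -- extras are preserved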
24 changes: 24 additions & 0 deletions GANDLF/Configuration/Parameters/nested_training_parameters.py
@@ -0,0 +1,24 @@
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self, Optional


class NestedTraining(BaseModel):
    stratified: bool = Field(
        default=False,
        description="Perform stratified k-fold cross-validation; only supported with offline data splitting.",
    )
    testing: int = Field(
        default=-5,
        description="Number of testing data folds for final model evaluation; use '1' to disable this ([NOT recommended]).",
    )
    validation: int = Field(
        default=-5,
        description="Number of validation data folds used for model *selection* during training (not used for back-propagation).",
    )
proportional: Optional[bool] = Field(default=None)

@model_validator(mode="after")
def validate_nested_training(self) -> Self:
if self.proportional is not None:
self.stratified = self.proportional
return self
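
The after-validator is a back-compatibility shim: when the older proportional key is present, it overwrites stratified. A standalone sketch of the same idiom (hypothetical NT model, pydantic v2):

from typing import Optional

from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self


class NT(BaseModel):
    stratified: bool = Field(default=False)
    proportional: Optional[bool] = Field(default=None)  # deprecated spelling

    @model_validator(mode="after")
    def alias_proportional(self) -> Self:
        if self.proportional is not None:
            self.stratified = self.proportional  # the old key wins when given
        return self


print(NT(proportional=True).stratified)  # True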
10 changes: 10 additions & 0 deletions GANDLF/Configuration/Parameters/optimizer_parameters.py
@@ -0,0 +1,10 @@
from pydantic import BaseModel, Field
from typing_extensions import Literal
from GANDLF.optimizers import global_optimizer_dict

# takes the keys from global optimizer
OPTIMIZER_OPTIONS = Literal[tuple(global_optimizer_dict.keys())]


class Optimizer(BaseModel):
type: OPTIMIZER_OPTIONS = Field(description="Type of optimizer to use")
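
Building the Literal from the registry keys turns an unknown optimizer into a parse-time ValidationError rather than a KeyError deep inside training. A sketch with a hypothetical two-entry registry (Literal accepts a runtime tuple of strings, the same trick used above):

from pydantic import BaseModel, ValidationError
from typing_extensions import Literal

registry = {"adam": None, "sgd": None}  # stand-in for global_optimizer_dict
OPTIMIZER_OPTIONS = Literal[tuple(registry.keys())]


class Opt(BaseModel):
    type: OPTIMIZER_OPTIONS


print(Opt(type="adam").type)  # ok: 'adam'
try:
    Opt(type="adamw")  # not in the registry
except ValidationError as err:
    print(err.errors()[0]["type"])  # 'literal_error'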
12 changes: 12 additions & 0 deletions GANDLF/Configuration/Parameters/parameters.py
@@ -0,0 +1,12 @@
from pydantic import BaseModel, ConfigDict
from GANDLF.Configuration.Parameters.user_defined_parameters import (
UserDefinedParameters,
)


class ParametersConfiguration(BaseModel):
model_config = ConfigDict(extra="allow")


class Parameters(ParametersConfiguration, UserDefinedParameters):
pass
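
Parameters inherits from both classes so that the validated field set and the extra-tolerant ConfigDict are merged into one model: anything the schema does not declare is carried through rather than dropped. A sketch of the same composition with hypothetical classes (pydantic v2 merges model_config across bases):

from pydantic import BaseModel, ConfigDict


class ExtraTolerant(BaseModel):
    model_config = ConfigDict(extra="allow")


class Required(BaseModel):
    modality: str


class Combined(ExtraTolerant, Required):
    pass


c = Combined(modality="rad", legacy_key=42)
print(c.model_dump())  # {'modality': 'rad', 'legacy_key': 42}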
8 changes: 8 additions & 0 deletions GANDLF/Configuration/Parameters/patch_sampler.py
@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field


class PatchSampler(BaseModel):
type: str = Field(default="uniform")
enable_padding: bool = Field(default=False)
padding_mode: str = Field(default="symmetric")
biased_sampling: bool = Field(default=False)
17 changes: 17 additions & 0 deletions GANDLF/Configuration/Parameters/scheduler_parameters.py
@@ -0,0 +1,17 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal, Optional

from GANDLF.schedulers import global_schedulers_dict

TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]


# It allows extra parameters
class Scheduler(BaseModel):
model_config = ConfigDict(extra="allow")
    type: TYPE_OPTIONS = Field(
        description="triangle/triangle_modified use LambdaLR, while triangular/triangular2/exp_range use CyclicLR."
    )
    # TODO: should min_lr (e.g. 0.00001) and max_lr (e.g. 1) be declared as fields?
    step_size: Optional[float] = Field(description="Scheduler step size.", default=None)
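
Because of extra="allow", scheduler-specific knobs with no declared field (such as the min_lr/max_lr noted in the TODO) still reach the scheduler factory. A sketch with a hypothetical Sched model:

from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Sched(BaseModel):
    model_config = ConfigDict(extra="allow")
    type: str
    step_size: Optional[float] = Field(default=None)


s = Sched(type="triangle", min_lr=1e-5, max_lr=1.0)  # undeclared keys are accepted
print(s.model_extra)  # {'min_lr': 1e-05, 'max_lr': 1.0}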
111 changes: 111 additions & 0 deletions GANDLF/Configuration/Parameters/user_defined_parameters.py
@@ -0,0 +1,111 @@
from typing import Union
from pydantic import BaseModel, model_validator, Field, AfterValidator
from GANDLF.Configuration.Parameters.default_parameters import DefaultParameters
from GANDLF.Configuration.Parameters.nested_training_parameters import NestedTraining
from GANDLF.Configuration.Parameters.patch_sampler import PatchSampler
from GANDLF.utils import version_check
from importlib.metadata import version
from typing_extensions import Self, Literal, Annotated, Any
from GANDLF.Configuration.Parameters.validators import *
from GANDLF.Configuration.Parameters.model_parameters import Model


class Version(BaseModel):  # TODO: maybe move this into its own module
    minimum: str
    maximum: str

    @model_validator(mode="after")
    def validate_version(self) -> Self:
        # version_check raises on an incompatible GANDLF version, so reaching
        # the return means the configured range is valid
        version_check(self.model_dump(), version_to_check=version("GANDLF"))
        return self


class InferenceMechanism(BaseModel):
grid_aggregator_overlap: Literal["crop", "average"] = Field(default="crop")
patch_overlap: int = Field(default=0)


class UserDefinedParameters(DefaultParameters):
    version: Version = Field(
        default=Version(minimum=version("GANDLF"), maximum=version("GANDLF")),
        description="GANDLF version compatibility range for this configuration.",
    )
patch_size: Union[list[Union[int, float]], int, float] = Field(
description="Patch size."
)
    model: Model = Field(..., description="The model to use.")
modality: Literal["rad", "histo", "path"] = Field(description="Modality.")
loss_function: Annotated[
Union[dict, str],
Field(description="Loss function."),
AfterValidator(validate_loss_function),
]
metrics: Annotated[
Union[dict, list[Union[str, dict, set]]],
Field(description="Metrics."),
AfterValidator(validate_metrics),
]
nested_training: NestedTraining = Field(description="Nested training.")
parallel_compute_command: str = Field(
default="", description="Parallel compute command."
)
scheduler: Union[str, Scheduler] = Field(
description="Scheduler.", default=Scheduler(type="triangle_modified")
)
optimizer: Union[str, Optimizer] = Field(
description="Optimizer.", default=Optimizer(type="adam")
) # TODO: Check it again for (opt)
patch_sampler: Union[str, PatchSampler] = Field(
description="Patch sampler.", default=PatchSampler()
)
inference_mechanism: InferenceMechanism = Field(
description="Inference mechanism.", default=InferenceMechanism()
)
data_postprocessing_after_reverse_one_hot_encoding: dict = Field(
description="data_postprocessing_after_reverse_one_hot_encoding.", default={}
)
differential_privacy: Any = Field(description="Differential privacy.", default=None)
# TODO: It should be defined with a better way (using a BaseModel class)
data_preprocessing: Annotated[
dict,
Field(description="Data preprocessing."),
AfterValidator(validate_data_preprocessing),
] = {}
# TODO: It should be defined with a better way (using a BaseModel class)
data_augmentation: Annotated[dict, Field(description="Data augmentation.")] = {}

# Validators
@model_validator(mode="after")
    def validate_parameters(self) -> Self:  # renamed to avoid shadowing the deprecated BaseModel.validate
# validate the patch_size
self.patch_size, self.model.dimension = validate_patch_size(
self.patch_size, self.model.dimension
)
# validate the parallel_compute_command
self.parallel_compute_command = validate_parallel_compute_command(
self.parallel_compute_command
)
# validate scheduler
self.scheduler = validate_schedular(self.scheduler, self.learning_rate)
# validate optimizer
self.optimizer = validate_optimizer(self.optimizer)
# validate patch_sampler
self.patch_sampler = validate_patch_sampler(self.patch_sampler)
# validate_data_augmentation
self.data_augmentation = validate_data_augmentation(
self.data_augmentation, self.patch_size
)
# validate data_postprocessing_after_reverse_one_hot_encoding
(
self.data_postprocessing_after_reverse_one_hot_encoding,
self.data_postprocessing,
) = validate_data_postprocessing_after_reverse_one_hot_encoding(
self.data_postprocessing_after_reverse_one_hot_encoding,
self.data_postprocessing,
)
# validate differential_privacy
self.differential_privacy = validate_differential_privacy(
self.differential_privacy, self.batch_size
)

return self
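
Most of the per-field handling above hangs off the Annotated + AfterValidator idiom: pydantic parses the raw value first, then the validator rewrites it in place. A standalone sketch of how a loss_function-style str-or-dict field can be normalized (hypothetical normalize_loss, not GANDLF's validator):

from typing import Union

from pydantic import AfterValidator, BaseModel, Field
from typing_extensions import Annotated


def normalize_loss(value: Union[dict, str]) -> dict:
    # accept "dice" as shorthand for {"dice": {}}
    return {value: {}} if isinstance(value, str) else value


class Cfg(BaseModel):
    loss_function: Annotated[
        Union[dict, str],
        Field(description="Loss function."),
        AfterValidator(normalize_loss),
    ]


print(Cfg(loss_function="dice").loss_function)  # {'dice': {}}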