Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -849,9 +849,11 @@ def check_input_channels(self, x):
return False
return True

def forward(self, x, return_mae_mask=False):
def forward(self, x, return_mae_mask=False, apply_activation=None):
# Check input channels and warn if mismatch
self.check_input_channels(x)
if apply_activation is None:
apply_activation = not self.training

# Get features from encoder (works for both U-Net and Primus)
# For MAE training with Primus, we need to get the mask
Expand Down Expand Up @@ -897,7 +899,7 @@ def forward(self, x, return_mae_mask=False):
logits = logits[0]
logits = self._apply_z_projection(task_name, logits)
activation_fn = self.task_activations[task_name] if task_name in self.task_activations else None
if activation_fn is not None and not self.training:
if activation_fn is not None and apply_activation:
if isinstance(logits, (list, tuple)):
logits = type(logits)(activation_fn(l) for l in logits)
else:
Expand All @@ -912,7 +914,7 @@ def forward(self, x, return_mae_mask=False):
logits = head(shared_features)
logits = self._apply_z_projection(task_name, logits)
activation_fn = self.task_activations[task_name] if task_name in self.task_activations else None
if activation_fn is not None and not self.training:
if activation_fn is not None and apply_activation:
if isinstance(logits, (list, tuple)):
logits = type(logits)(activation_fn(l) for l in logits)
else:
Expand Down
10 changes: 10 additions & 0 deletions vesuvius/src/vesuvius/models/configuration/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ def _init_attributes(self):

# Chunk-slicing worker configuration
self.valid_patch_find_resolution = int(self.dataset_config.get("valid_patch_find_resolution", 1))
self.ome_zarr_resolution = int(self.dataset_config.get("ome_zarr_resolution", 0))
if self.ome_zarr_resolution < 0:
raise ValueError(
f"dataset_config.ome_zarr_resolution must be >= 0, got {self.ome_zarr_resolution}"
)
if self.valid_patch_find_resolution < self.ome_zarr_resolution:
raise ValueError(
"dataset_config.valid_patch_find_resolution must be >= dataset_config.ome_zarr_resolution "
f"(got {self.valid_patch_find_resolution} < {self.ome_zarr_resolution})"
)
self.num_workers = int(self.dataset_config.get("num_workers", 8))

# Worker configuration for image→Zarr pipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s0_ps128_bs28_bcedice"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [128, 128, 128]
batch_size: 28
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 0
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 1
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "BinaryBCEAndDiceLoss"
weight: 1.0
weight_bce: 1.0
weight_dice: 1.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s0_ps128_bs28_msr"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [128, 128, 128]
batch_size: 28
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 0
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 2
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "MedialSurfaceRecall"
weight: 1.0

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s0_ps256_bs3_bcedice"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [256, 256, 256]
batch_size: 3
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 0
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 1
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "BinaryBCEAndDiceLoss"
weight: 1.0
weight_bce: 1.0
weight_dice: 1.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s0_ps256_bs3_msr"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [256, 256, 256]
batch_size: 3
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 0
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 2
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "MedialSurfaceRecall"
weight: 1.0

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s2_ps128_bs28_bcedice"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [128, 128, 128]
batch_size: 28
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 2
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 1
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "BinaryBCEAndDiceLoss"
weight: 1.0
weight_bce: 1.0
weight_dice: 1.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s2_ps128_bs28_msr"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [128, 128, 128]
batch_size: 28
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 2
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 2
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "MedialSurfaceRecall"
weight: 1.0

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s2_ps256_bs3_bcedice"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [256, 256, 256]
batch_size: 3
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 2
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 1
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "BinaryBCEAndDiceLoss"
weight: 1.0
weight_bce: 1.0
weight_dice: 1.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
tr_setup:
wandb_project: "srf_2um"
wandb_entity: "vesuvius-challenge"
model_name: "surface_resenc_s2_ps256_bs3_msr"
tr_val_split: 0.95
autoconfigure: false

tr_config:
patch_size: [256, 256, 256]
batch_size: 3
num_dataloader_workers: 14

model_config:
basic_encoder_block: "BasicBlockD"
bottleneck_block: "BasicBlockD"
basic_decoder_block: "ConvBlock"
norm_op: "nn.InstanceNorm3d"
nonlin: "nn.LeakyReLU"
features_per_stage: [32, 64, 128, 256, 320, 320]
n_stages: 6
n_blocks_per_stage: [1, 3, 4, 6, 6, 6]
n_conv_per_stage_decoder: [1, 1, 1, 1, 1]
kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
separate_decoders: true

dataset_config:
data_path: "/ephemeral/datasets"
ome_zarr_resolution: 2
min_labeled_ratio: 0.001
min_bbox_percent: 0.35
valid_patch_find_resolution: 3
targets:
surface:
out_channels: 2
valid_patch_value: 1
activation: "none"
ignore_label: 2
losses:
- name: "MedialSurfaceRecall"
weight: 1.0

Loading
Loading