Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
/data/brats2017_seg/BraTS2017_Training_Data/*
/data/brats2017_seg/brats2017_raw_data/train/*

/data/brats2021_seg/BraTS2021_Training_Data/*
/data/brats2021_seg/brats2021_raw_data/train


/visualization/*

.vscode/

**/wandb
**/__pycache__
Expand Down
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

2 changes: 2 additions & 0 deletions architectures/build_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
we import each of the architectures into this file. Once we have that
we can use a keyword from the config file to build the model.
"""


######################################################################
def build_architecture(config):
if config["model_name"] == "segformer3d":
Expand Down
75 changes: 75 additions & 0 deletions architectures/configuration_segformer3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List, Optional, Union

from transformers import PretrainedConfig


class SegFormer3DConfig(PretrainedConfig):
    """Configuration for the SegFormer3D volumetric segmentation model.

    Stores the hyperparameters of a 4-stage hierarchical 3D SegFormer
    encoder and its decoder head, following the HuggingFace
    ``PretrainedConfig`` convention so the model can round-trip through
    ``save_pretrained`` / ``from_pretrained``.
    """

    model_type = "segformer3d"

    def __init__(
        self,
        in_channels: int = 4,
        # NOTE: the per-stage list parameters default to None and are
        # materialized inside __init__. Using list literals as defaults
        # would share ONE list object across every call; since these are
        # assigned straight onto self, mutating one config instance
        # (e.g. config.embed_dims.append(...)) would corrupt the default
        # for all subsequently created configs.
        sr_ratios: Optional[List[int]] = None,
        embed_dims: Optional[List[int]] = None,
        patch_kernel_size: Optional[List[Union[int, List[int]]]] = None,
        patch_stride: Optional[List[Union[int, List[int]]]] = None,
        patch_padding: Optional[List[Union[int, List[int]]]] = None,
        mlp_ratios: Optional[List[int]] = None,
        num_heads: Optional[List[int]] = None,
        depths: Optional[List[int]] = None,
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        decoder_dropout: float = 0.0,
        qkv_bias: bool = True,
        attention_dropout: float = 0.0,
        projection_dropout: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            sr_ratios (List[int]): Spatial reduction ratios for each stage.
                Defaults to ``[4, 2, 1, 1]``.
            embed_dims (List[int]): Embedding dimensions for each stage.
                Defaults to ``[32, 64, 160, 256]``.
            patch_kernel_size (List[Union[int, List[int]]]): Kernel sizes for
                patch embedding, one entry per stage. Defaults to
                ``[[7, 7, 7], [3, 3, 3], [3, 3, 3], [3, 3, 3]]``.
            patch_stride (List[Union[int, List[int]]]): Stride values for
                patch embedding, one entry per stage. Defaults to
                ``[[4, 4, 4], [2, 2, 2], [2, 2, 2], [2, 2, 2]]``.
            patch_padding (List[Union[int, List[int]]]): Padding values for
                patch embedding, one entry per stage. Defaults to
                ``[[3, 3, 3], [1, 1, 1], [1, 1, 1], [1, 1, 1]]``.
            mlp_ratios (List[int]): MLP expansion ratios for each stage.
                Defaults to ``[4, 4, 4, 4]``.
            num_heads (List[int]): Number of attention heads for each stage.
                Defaults to ``[1, 2, 5, 8]``.
            depths (List[int]): Number of transformer blocks per stage.
                Defaults to ``[2, 2, 2, 2]``.
            decoder_head_embedding_dim (int): Embedding dimension in decoder head.
            num_classes (int): Number of output classes.
            decoder_dropout (float): Dropout rate in decoder.
            qkv_bias (bool): Whether to use bias in QKV projections.
            attention_dropout (float): Dropout rate for attention.
            projection_dropout (float): Dropout rate for projections.
            **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``name_or_path``).
        """
        super().__init__(**kwargs)
        self.in_channels = in_channels
        # Materialize a fresh list per instance — either a copy of the
        # caller's argument or a new default — so no two config objects
        # (and no module-level default) ever alias the same list.
        self.sr_ratios = [4, 2, 1, 1] if sr_ratios is None else list(sr_ratios)
        self.embed_dims = (
            [32, 64, 160, 256] if embed_dims is None else list(embed_dims)
        )
        self.patch_kernel_size = (
            [[7, 7, 7], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
            if patch_kernel_size is None
            else list(patch_kernel_size)
        )
        self.patch_stride = (
            [[4, 4, 4], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
            if patch_stride is None
            else list(patch_stride)
        )
        self.patch_padding = (
            [[3, 3, 3], [1, 1, 1], [1, 1, 1], [1, 1, 1]]
            if patch_padding is None
            else list(patch_padding)
        )
        self.mlp_ratios = [4, 4, 4, 4] if mlp_ratios is None else list(mlp_ratios)
        self.num_heads = [1, 2, 5, 8] if num_heads is None else list(num_heads)
        self.depths = [2, 2, 2, 2] if depths is None else list(depths)
        self.decoder_head_embedding_dim = decoder_head_embedding_dim
        self.num_classes = num_classes
        self.decoder_dropout = decoder_dropout
        self.qkv_bias = qkv_bias
        self.attention_dropout = attention_dropout
        self.projection_dropout = projection_dropout
Loading