diff --git a/.gitignore b/.gitignore index 0d06eb9..55ea5a9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,4 @@ -/data/brats2017_seg/BraTS2017_Training_Data/* -/data/brats2017_seg/brats2017_raw_data/train/* - -/data/brats2021_seg/BraTS2021_Training_Data/* -/data/brats2021_seg/brats2021_raw_data/train - - -/visualization/* - +.vscode/ **/wandb **/__pycache__ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 22f60ba..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "ros.distro": "foxy" -} \ No newline at end of file diff --git a/architectures/build_architecture.py b/architectures/build_architecture.py index 9e19f6b..54f1b0a 100644 --- a/architectures/build_architecture.py +++ b/architectures/build_architecture.py @@ -3,6 +3,8 @@ we import each of the architectures into this file. Once we have that we can use a keyword from the config file to build the model. """ + + ###################################################################### def build_architecture(config): if config["model_name"] == "segformer3d": diff --git a/architectures/configuration_segformer3d.py b/architectures/configuration_segformer3d.py new file mode 100644 index 0000000..dfda9bc --- /dev/null +++ b/architectures/configuration_segformer3d.py @@ -0,0 +1,75 @@ +from transformers import PretrainedConfig +from typing import List, Union + + +class SegFormer3DConfig(PretrainedConfig): + model_type = "segformer3d" + + def __init__( + self, + in_channels: int = 4, + sr_ratios: List[int] = [4, 2, 1, 1], + embed_dims: List[int] = [32, 64, 160, 256], + patch_kernel_size: List[Union[int, List[int]]] = [ + [7, 7, 7], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + patch_stride: List[Union[int, List[int]]] = [ + [4, 4, 4], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + ], + patch_padding: List[Union[int, List[int]]] = [ + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + ], + mlp_ratios: List[int] = [4, 4, 4, 4], + num_heads: List[int] = [1, 2, 5, 8], + depths: List[int] = [2, 2, 2, 2], + decoder_head_embedding_dim: int = 256, + num_classes: int = 3, + decoder_dropout: float = 0.0, + qkv_bias: bool = True, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + **kwargs + ): + """ + Args: + in_channels (int): Number of input channels + sr_ratios (List[int]): Spatial reduction ratios for each stage + embed_dims (List[int]): Embedding dimensions for each stage + patch_kernel_size (List[int]): Kernel sizes for patch embedding + patch_stride (List[int]): Stride values for patch embedding + patch_padding (List[int]): Padding values for patch embedding + mlp_ratios (List[int]): MLP expansion ratios for each stage + num_heads (List[int]): Number of attention heads for each stage + depths (List[int]): Number of transformer blocks per stage + decoder_head_embedding_dim (int): Embedding dimension in decoder head + num_classes (int): Number of output classes + decoder_dropout (float): Dropout rate in decoder + qkv_bias (bool): Whether to use bias in QKV projections + attention_dropout (float): Dropout rate for attention + projection_dropout (float): Dropout rate for projections + """ + super().__init__(**kwargs) + self.in_channels = in_channels + self.sr_ratios = sr_ratios + self.embed_dims = embed_dims + self.patch_kernel_size = patch_kernel_size + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.mlp_ratios = mlp_ratios + self.num_heads = num_heads + self.depths = depths + self.decoder_head_embedding_dim = decoder_head_embedding_dim + 
self.num_classes = num_classes + self.decoder_dropout = decoder_dropout + self.qkv_bias = qkv_bias + self.attention_dropout = attention_dropout + self.projection_dropout = projection_dropout diff --git a/architectures/modeling_segformer3d.py b/architectures/modeling_segformer3d.py new file mode 100644 index 0000000..3c4024f --- /dev/null +++ b/architectures/modeling_segformer3d.py @@ -0,0 +1,466 @@ +"""SegFormer3D model implementation compatible with HuggingFace Transformers""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.modeling_outputs import SemanticSegmenterOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) + +from .configuration_segformer3d import SegFormer3DConfig + + +class PatchEmbedding(nn.Module): + def __init__( + self, + in_channel: int = 4, + embed_dim: int = 768, + kernel_size: Union[int, List[int]] = [7, 7, 7], + stride: Union[int, List[int]] = [4, 4, 4], + padding: Union[int, List[int]] = [3, 3, 3], + ): + super().__init__() + # Convert single integers to lists if necessary + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * 3 + if isinstance(stride, int): + stride = [stride] * 3 + if isinstance(padding, int): + padding = [padding] * 3 + + self.patch_embeddings = nn.Conv3d( + in_channel, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + self.norm = nn.LayerNorm(embed_dim) + + # Store for shape calculations + self.stride = stride + self.padding = padding + self.kernel_size = kernel_size + + def forward(self, x): + patches = self.patch_embeddings(x) + patches = patches.flatten(2).transpose(1, 2) + patches = self.norm(patches) + return patches + + def get_output_shape(self, input_shape): + """Calculate output spatial dimensions after convolution""" + d, h, w = input_shape + od = ((d + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0]) + 1 + oh = ((h + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1]) + 1 + ow = ((w + 2 * self.padding[2] - self.kernel_size[2]) // self.stride[2]) + 1 + return (od, oh, ow) + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, spatial_shape): + B, N, C = x.shape + d, h, w = spatial_shape + x = x.transpose(1, 2).view(B, C, d, h, w) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class SelfAttention(nn.Module): + def __init__( + self, + config: SegFormer3DConfig, + embed_dim: int, + num_heads: int, + sr_ratio: int, + ): + super().__init__() + self.num_heads = num_heads + self.attention_head_dim = embed_dim // num_heads + self.scale = self.attention_head_dim**-0.5 + + self.query = nn.Linear(embed_dim, embed_dim, bias=config.qkv_bias) + self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=config.qkv_bias) + self.attn_dropout = nn.Dropout(config.attention_dropout) + self.proj = nn.Linear(embed_dim, embed_dim) + self.proj_dropout = nn.Dropout(config.projection_dropout) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv3d( + embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio + ) + self.sr_norm = nn.LayerNorm(embed_dim) + + def forward(self, x, spatial_shape, output_attentions: bool = False): + D, H, W = spatial_shape + B, N, C = x.shape + q = ( + self.query(x) + .reshape(B, N, 
self.num_heads, self.attention_head_dim) + .permute(0, 2, 1, 3) + ) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, D, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.sr_norm(x_) + kv = ( + self.key_value(x_) + .reshape(B, -1, 2, self.num_heads, self.attention_head_dim) + .permute(2, 0, 3, 1, 4) + ) + else: + kv = ( + self.key_value(x) + .reshape(B, -1, 2, self.num_heads, self.attention_head_dim) + .permute(2, 0, 3, 1, 4) + ) + + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_dropout(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_dropout(x) + + if output_attentions: + return x, attn + return (x,) + + +class MLP(nn.Module): + def __init__(self, embed_dim, hidden_features, dropout): + super().__init__() + self.linear_1 = nn.Linear(embed_dim, hidden_features) + self.conv = DWConv(hidden_features) + self.linear_2 = nn.Linear(hidden_features, embed_dim) + self.activation = nn.GELU() + self.dropout = nn.Dropout(config.projection_dropout) + + def forward(self, x, spatial_shape): + x = self.linear_1(x) + x = self.conv(x, spatial_shape) + x = self.linear_2(x) + x = self.activation(x) + x = self.dropout(x) + return x + + +class TransformerBlock(nn.Module): + def __init__( + self, + config: SegFormer3DConfig, + embed_dim: int, + num_heads: int, + sr_ratio: int, + mlp_ratio: int, + ): + super().__init__() + self.norm1 = nn.LayerNorm(embed_dim) + self.attention = SelfAttention( + config=config, + embed_dim=embed_dim, + num_heads=num_heads, + sr_ratio=sr_ratio, + ) + self.norm2 = nn.LayerNorm(embed_dim) + hidden_features = int(embed_dim * mlp_ratio) + + self.mlp = MLP(embed_dim, hidden_features, config.projection_dropout) + + def forward(self, x, spatial_shape, output_attentions: bool = False): + attention_outputs = self.attention( + self.norm1(x), spatial_shape, output_attentions + ) + x = x + attention_outputs[0] + x = x + self.mlp(self.norm2(x), spatial_shape) + + outputs = (x,) + attention_outputs[1:] if output_attentions else (x,) + return outputs, spatial_shape + + +@add_start_docstrings("""SegFormer3D Model for 3D semantic segmentation.""") +class SegFormer3DPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading + and loading pretrained models. 
+ """ + + config_class = SegFormer3DConfig + base_model_prefix = "segformer3d" + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, nn.BatchNorm2d): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, nn.BatchNorm3d): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, nn.Conv2d): + fan_out = ( + module.kernel_size[0] * module.kernel_size[1] * module.out_channels + ) + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Conv3d): + fan_out = ( + module.kernel_size[0] + * module.kernel_size[1] + * module.kernel_size[2] + * module.out_channels + ) + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() + + +@add_start_docstrings("""SegFormer3D Model for 3D semantic segmentation tasks.""") +class SegFormer3DModel(SegFormer3DPreTrainedModel): + def __init__(self, config: SegFormer3DConfig): + super().__init__(config) + + # Encoder components + self.encoders = nn.ModuleList() + self.encoder_norms = nn.ModuleList() + self.transformer_blocks = nn.ModuleList() + + # Build hierarchical encoder stages + for i in range(4): + # Patch embedding for this stage + encoder = PatchEmbedding( + in_channel=config.in_channels if i == 0 else config.embed_dims[i - 1], + embed_dim=config.embed_dims[i], + kernel_size=config.patch_kernel_size[i], + stride=config.patch_stride[i], + padding=config.patch_padding[i], + ) + self.encoders.append(encoder) + + # Transformer blocks for this stage + stage_blocks = nn.ModuleList( + [ + TransformerBlock( + config=config, + embed_dim=config.embed_dims[i], + num_heads=config.num_heads[i], + sr_ratio=config.sr_ratios[i], + mlp_ratio=config.mlp_ratios[i], + ) + for _ in range(config.depths[i]) + ] + ) + self.transformer_blocks.append(stage_blocks) + + # Layer norm for this stage + self.encoder_norms.append(nn.LayerNorm(config.embed_dims[i])) + + # Decoder components + self.decoder = SegFormer3DDecoderHead(config) + + # Initialize weights + self.post_init() + + def get_input_embeddings(self): + return self.encoders[0].patch_embeddings + + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SemanticSegmenterOutput]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + encoder_hidden_states = [] + all_attentions = [] if output_attentions else None + + x = pixel_values + + # Process through encoder stages + for stage_idx in range(4): + # Track input spatial dimensions + if stage_idx == 0: + spatial_shape = pixel_values.shape[2:] # (D, H, W) + + # Patch embedding + x = self.encoders[stage_idx](x) + spatial_shape = 
self.encoders[stage_idx].get_output_shape(spatial_shape) + B, N, C = x.shape + + # Transformer blocks + for block in self.transformer_blocks[stage_idx]: + block_outputs, spatial_shape = block( + x, spatial_shape, output_attentions + ) + x = block_outputs[0] + if output_attentions: + all_attentions.append(block_outputs[1]) + + # Layer norm + x = self.encoder_norms[stage_idx](x) + + # Reshape and store hidden state using calculated dimensions + d, h, w = spatial_shape + x_reshaped = x.reshape(B, d, h, w, -1).permute(0, 4, 1, 2, 3).contiguous() + encoder_hidden_states.append(x_reshaped) + + # Prepare input for next stage if not last stage + if stage_idx < 3: + x = x_reshaped + + # Decode features + logits = self.decoder(encoder_hidden_states) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + else: + # compute loss for 3D semantic segmentation + loss_fct = CrossEntropyLoss(ignore_index=255) + loss = loss_fct( + logits.view(-1, self.config.num_classes), labels.view(-1) + ) + + if not return_dict: + outputs = ( + (logits,) + + (encoder_hidden_states if output_hidden_states else ()) + + (all_attentions if output_attentions else ()) + ) + return ((loss,) + outputs) if loss is not None else outputs + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=encoder_hidden_states if output_hidden_states else None, + attentions=all_attentions, + ) + + +class SegFormer3DDecoderHead(nn.Module): + def __init__(self, config: SegFormer3DConfig): + super().__init__() + + # Linear layers for each encoder stage + self.linear_layers = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(dim, config.decoder_head_embedding_dim), + nn.LayerNorm(config.decoder_head_embedding_dim), + ) + for dim in config.embed_dims[::-1] + ] + ) + + # Feature fusion + self.linear_fuse = nn.Sequential( + nn.Conv3d( + in_channels=4 * config.decoder_head_embedding_dim, + out_channels=config.decoder_head_embedding_dim, + kernel_size=1, + bias=False, + ), + nn.BatchNorm3d(config.decoder_head_embedding_dim), + nn.ReLU(), + ) + + self.dropout = nn.Dropout(config.decoder_dropout) + self.linear_pred = nn.Conv3d( + config.decoder_head_embedding_dim, config.num_classes, kernel_size=1 + ) + self.upsample = nn.Upsample( + scale_factor=4, mode="trilinear", align_corners=False + ) + + def forward(self, encoder_hidden_states): + # Process features from each encoder stage + B = encoder_hidden_states[-1].shape[0] + + # Linear projection and upsampling of each stage's features + decoded_features = [] + for i, features in enumerate( + encoder_hidden_states[::-1] + ): # Process in reverse order + d, h, w = features.shape[2:] + projected = ( + self.linear_layers[i](features.flatten(2).transpose(1, 2)) + .transpose(1, 2) + .reshape(B, -1, d, h, w) + ) + + # Upsample if not the last feature map + if i != len(encoder_hidden_states[::-1]): + projected = torch.nn.functional.interpolate( + projected, + size=encoder_hidden_states[0].shape[ + 2: + ], # Size of first stage features + mode="trilinear", + align_corners=False, + ) + decoded_features.append(projected) + + # Fuse all features + fused_features = self.linear_fuse(torch.cat(decoded_features, dim=1)) + + # Final prediction + x = self.dropout(fused_features) + x = self.linear_pred(x) + x = self.upsample(x) + + return x + + +if __name__ == "__main__": + input = torch.randint( + low=0, + high=255, + size=(1, 4, 128, 128, 128), + dtype=torch.float, + ) + input = input.to("cuda:0") + config = 
SegFormer3DConfig() + segformer3D = SegFormer3DModel(config).to("cuda:0") + output = segformer3D(input) + print(output["logits"].shape) diff --git a/architectures/segformer3d.py b/architectures/segformer3d.py deleted file mode 100644 index 6355148..0000000 --- a/architectures/segformer3d.py +++ /dev/null @@ -1,648 +0,0 @@ -import torch -import math -import copy -from torch import nn -from einops import rearrange -from functools import partial - -def build_segformer3d_model(config=None): - model = SegFormer3D( - in_channels=config["model_parameters"]["in_channels"], - sr_ratios=config["model_parameters"]["sr_ratios"], - embed_dims=config["model_parameters"]["embed_dims"], - patch_kernel_size=config["model_parameters"]["patch_kernel_size"], - patch_stride=config["model_parameters"]["patch_stride"], - patch_padding=config["model_parameters"]["patch_padding"], - mlp_ratios=config["model_parameters"]["mlp_ratios"], - num_heads=config["model_parameters"]["num_heads"], - depths=config["model_parameters"]["depths"], - decoder_head_embedding_dim=config["model_parameters"][ - "decoder_head_embedding_dim" - ], - num_classes=config["model_parameters"]["num_classes"], - decoder_dropout=config["model_parameters"]["decoder_dropout"], - ) - return model - - -class SegFormer3D(nn.Module): - def __init__( - self, - in_channels: int = 4, - sr_ratios: list = [4, 2, 1, 1], - embed_dims: list = [32, 64, 160, 256], - patch_kernel_size: list = [7, 3, 3, 3], - patch_stride: list = [4, 2, 2, 2], - patch_padding: list = [3, 1, 1, 1], - mlp_ratios: list = [4, 4, 4, 4], - num_heads: list = [1, 2, 5, 8], - depths: list = [2, 2, 2, 2], - decoder_head_embedding_dim: int = 256, - num_classes: int = 3, - decoder_dropout: float = 0.0, - ): - """ - in_channels: number of the input channels - img_volume_dim: spatial resolution of the image volume (Depth, Width, Height) - sr_ratios: the rates at which to down sample the sequence length of the embedded patch - embed_dims: hidden size of the PatchEmbedded input - patch_kernel_size: kernel size for the convolution in the patch embedding module - patch_stride: stride for the convolution in the patch embedding module - patch_padding: padding for the convolution in the patch embedding module - mlp_ratios: at which rate increases the projection dim of the hidden_state in the mlp - num_heads: number of attention heads - depths: number of attention layers - decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module - num_classes: number of the output channel of the network - decoder_dropout: dropout rate of the concatenated feature maps - - """ - super().__init__() - self.segformer_encoder = MixVisionTransformer( - in_channels=in_channels, - sr_ratios=sr_ratios, - embed_dims=embed_dims, - patch_kernel_size=patch_kernel_size, - patch_stride=patch_stride, - patch_padding=patch_padding, - mlp_ratios=mlp_ratios, - num_heads=num_heads, - depths=depths, - ) - # decoder takes in the feature maps in the reversed order - reversed_embed_dims = embed_dims[::-1] - self.segformer_decoder = SegFormerDecoderHead( - input_feature_dims=reversed_embed_dims, - decoder_head_embedding_dim=decoder_head_embedding_dim, - num_classes=num_classes, - dropout=decoder_dropout, - ) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - 
nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.BatchNorm3d): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Conv3d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - - def forward(self, x): - # embedding the input - x = self.segformer_encoder(x) - # # unpacking the embedded features generated by the transformer - c1 = x[0] - c2 = x[1] - c3 = x[2] - c4 = x[3] - # decoding the embedded features - x = self.segformer_decoder(c1, c2, c3, c4) - return x - -# ----------------------------------------------------- encoder ----------------------------------------------------- -class PatchEmbedding(nn.Module): - def __init__( - self, - in_channel: int = 4, - embed_dim: int = 768, - kernel_size: int = 7, - stride: int = 4, - padding: int = 3, - ): - """ - in_channels: number of the channels in the input volume - embed_dim: embedding dimmesion of the patch - """ - super().__init__() - self.patch_embeddings = nn.Conv3d( - in_channel, - embed_dim, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - # standard embedding patch - patches = self.patch_embeddings(x) - patches = patches.flatten(2).transpose(1, 2) - patches = self.norm(patches) - return patches - - -class SelfAttention(nn.Module): - def __init__( - self, - embed_dim: int = 768, - num_heads: int = 8, - sr_ratio: int = 2, - qkv_bias: bool = False, - attn_dropout: float = 0.0, - proj_dropout: float = 0.0, - ): - """ - embed_dim : hidden size of the PatchEmbedded input - num_heads: number of attention heads - sr_ratio: the rate at which to down sample the sequence length of the embedded patch - qkv_bias: whether or not the linear projection has bias - attn_dropout: the dropout rate of the attention component - proj_dropout: the dropout rate of the final linear projection - """ - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), "Embedding dim should be divisible by number of heads!" 
- - self.num_heads = num_heads - # embedding dimesion of each attention head - self.attention_head_dim = embed_dim // num_heads - - # The same input is used to generate the query, key, and value, - # (batch_size, num_patches, hidden_size) -> (batch_size, num_patches, attention_head_size) - self.query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) - self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_bias) - self.attn_dropout = nn.Dropout(attn_dropout) - self.proj = nn.Linear(embed_dim, embed_dim) - self.proj_dropout = nn.Dropout(proj_dropout) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv3d( - embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio - ) - self.sr_norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - # (batch_size, num_patches, hidden_size) - B, N, C = x.shape - - # (batch_size, num_head, sequence_length, embed_dim) - q = ( - self.query(x) - .reshape(B, N, self.num_heads, self.attention_head_dim) - .permute(0, 2, 1, 3) - ) - - if self.sr_ratio > 1: - n = cube_root(N) - # (batch_size, sequence_length, embed_dim) -> (batch_size, embed_dim, patch_D, patch_H, patch_W) - x_ = x.permute(0, 2, 1).reshape(B, C, n, n, n) - # (batch_size, embed_dim, patch_D, patch_H, patch_W) -> (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - # (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio) -> (batch_size, sequence_length, embed_dim) - # normalizing the layer - x_ = self.sr_norm(x_) - # (batch_size, num_patches, hidden_size) - kv = ( - self.key_value(x_) - .reshape(B, -1, 2, self.num_heads, self.attention_head_dim) - .permute(2, 0, 3, 1, 4) - ) - # (2, batch_size, num_heads, num_sequence, attention_head_dim) - else: - # (batch_size, num_patches, hidden_size) - kv = ( - self.key_value(x) - .reshape(B, -1, 2, self.num_heads, self.attention_head_dim) - .permute(2, 0, 3, 1, 4) - ) - # (2, batch_size, num_heads, num_sequence, attention_head_dim) - - k, v = kv[0], kv[1] - - attention_score = (q @ k.transpose(-2, -1)) / math.sqrt(self.num_heads) - attnention_prob = attention_score.softmax(dim=-1) - attnention_prob = self.attn_dropout(attnention_prob) - out = (attnention_prob @ v).transpose(1, 2).reshape(B, N, C) - out = self.proj(out) - out = self.proj_dropout(out) - return out - - -class TransformerBlock(nn.Module): - def __init__( - self, - embed_dim: int = 768, - mlp_ratio: int = 2, - num_heads: int = 8, - sr_ratio: int = 2, - qkv_bias: bool = False, - attn_dropout: float = 0.0, - proj_dropout: float = 0.0, - ): - """ - embed_dim : hidden size of the PatchEmbedded input - mlp_ratio: at which rate increasse the projection dim of the embedded patch in the _MLP component - num_heads: number of attention heads - sr_ratio: the rate at which to down sample the sequence length of the embedded patch - qkv_bias: whether or not the linear projection has bias - attn_dropout: the dropout rate of the attention component - proj_dropout: the dropout rate of the final linear projection - """ - super().__init__() - self.norm1 = nn.LayerNorm(embed_dim) - self.attention = SelfAttention( - embed_dim=embed_dim, - num_heads=num_heads, - sr_ratio=sr_ratio, - qkv_bias=qkv_bias, - attn_dropout=attn_dropout, - proj_dropout=proj_dropout, - ) - self.norm2 = nn.LayerNorm(embed_dim) - self.mlp = _MLP(in_feature=embed_dim, mlp_ratio=mlp_ratio, dropout=0.0) - - def forward(self, x): - x = x + self.attention(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - - -class 
MixVisionTransformer(nn.Module): - def __init__( - self, - in_channels: int = 4, - sr_ratios: list = [8, 4, 2, 1], - embed_dims: list = [64, 128, 320, 512], - patch_kernel_size: list = [7, 3, 3, 3], - patch_stride: list = [4, 2, 2, 2], - patch_padding: list = [3, 1, 1, 1], - mlp_ratios: list = [2, 2, 2, 2], - num_heads: list = [1, 2, 5, 8], - depths: list = [2, 2, 2, 2], - ): - """ - in_channels: number of the input channels - img_volume_dim: spatial resolution of the image volume (Depth, Width, Height) - sr_ratios: the rates at which to down sample the sequence length of the embedded patch - embed_dims: hidden size of the PatchEmbedded input - patch_kernel_size: kernel size for the convolution in the patch embedding module - patch_stride: stride for the convolution in the patch embedding module - patch_padding: padding for the convolution in the patch embedding module - mlp_ratio: at which rate increasse the projection dim of the hidden_state in the mlp - num_heads: number of attenion heads - depth: number of attention layers - """ - super().__init__() - - # patch embedding at different Pyramid level - self.embed_1 = PatchEmbedding( - in_channel=in_channels, - embed_dim=embed_dims[0], - kernel_size=patch_kernel_size[0], - stride=patch_stride[0], - padding=patch_padding[0], - ) - self.embed_2 = PatchEmbedding( - in_channel=embed_dims[0], - embed_dim=embed_dims[1], - kernel_size=patch_kernel_size[1], - stride=patch_stride[1], - padding=patch_padding[1], - ) - self.embed_3 = PatchEmbedding( - in_channel=embed_dims[1], - embed_dim=embed_dims[2], - kernel_size=patch_kernel_size[2], - stride=patch_stride[2], - padding=patch_padding[2], - ) - self.embed_4 = PatchEmbedding( - in_channel=embed_dims[2], - embed_dim=embed_dims[3], - kernel_size=patch_kernel_size[3], - stride=patch_stride[3], - padding=patch_padding[3], - ) - - # block 1 - self.tf_block1 = nn.ModuleList( - [ - TransformerBlock( - embed_dim=embed_dims[0], - num_heads=num_heads[0], - mlp_ratio=mlp_ratios[0], - sr_ratio=sr_ratios[0], - qkv_bias=True, - ) - for _ in range(depths[0]) - ] - ) - self.norm1 = nn.LayerNorm(embed_dims[0]) - - # block 2 - self.tf_block2 = nn.ModuleList( - [ - TransformerBlock( - embed_dim=embed_dims[1], - num_heads=num_heads[1], - mlp_ratio=mlp_ratios[1], - sr_ratio=sr_ratios[1], - qkv_bias=True, - ) - for _ in range(depths[1]) - ] - ) - self.norm2 = nn.LayerNorm(embed_dims[1]) - - # block 3 - self.tf_block3 = nn.ModuleList( - [ - TransformerBlock( - embed_dim=embed_dims[2], - num_heads=num_heads[2], - mlp_ratio=mlp_ratios[2], - sr_ratio=sr_ratios[2], - qkv_bias=True, - ) - for _ in range(depths[2]) - ] - ) - self.norm3 = nn.LayerNorm(embed_dims[2]) - - # block 4 - self.tf_block4 = nn.ModuleList( - [ - TransformerBlock( - embed_dim=embed_dims[3], - num_heads=num_heads[3], - mlp_ratio=mlp_ratios[3], - sr_ratio=sr_ratios[3], - qkv_bias=True, - ) - for _ in range(depths[3]) - ] - ) - self.norm4 = nn.LayerNorm(embed_dims[3]) - - def forward(self, x): - out = [] - # at each stage these are the following mappings: - # (batch_size, num_patches, hidden_state) - # (num_patches,) -> (D, H, W) - # (batch_size, num_patches, hidden_state) -> (batch_size, hidden_state, D, H, W) - - # stage 1 - x = self.embed_1(x) - B, N, C = x.shape - n = cube_root(N) - for i, blk in enumerate(self.tf_block1): - x = blk(x) - x = self.norm1(x) - # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W) - x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous() - out.append(x) - - # stage 2 - x = self.embed_2(x) - B, N, C = x.shape - n = 
cube_root(N) - for i, blk in enumerate(self.tf_block2): - x = blk(x) - x = self.norm2(x) - # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W) - x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous() - out.append(x) - - # stage 3 - x = self.embed_3(x) - B, N, C = x.shape - n = cube_root(N) - for i, blk in enumerate(self.tf_block3): - x = blk(x) - x = self.norm3(x) - # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W) - x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous() - out.append(x) - - # stage 4 - x = self.embed_4(x) - B, N, C = x.shape - n = cube_root(N) - for i, blk in enumerate(self.tf_block4): - x = blk(x) - x = self.norm4(x) - # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W) - x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous() - out.append(x) - - return out - - -class _MLP(nn.Module): - def __init__(self, in_feature, mlp_ratio=2, dropout=0.0): - super().__init__() - out_feature = mlp_ratio * in_feature - self.fc1 = nn.Linear(in_feature, out_feature) - self.dwconv = DWConv(dim=out_feature) - self.fc2 = nn.Linear(out_feature, in_feature) - self.act_fn = nn.GELU() - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) - x = self.dwconv(x) - x = self.act_fn(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class DWConv(nn.Module): - def __init__(self, dim=768): - super().__init__() - self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim) - # added batchnorm (remove it ?) - self.bn = nn.BatchNorm3d(dim) - - def forward(self, x): - B, N, C = x.shape - # (batch, patch_cube, hidden_size) -> (batch, hidden_size, D, H, W) - # assuming D = H = W, i.e. cube root of the patch is an integer number! - n = cube_root(N) - x = x.transpose(1, 2).view(B, C, n, n, n) - x = self.dwconv(x) - # added batchnorm (remove it ?) - x = self.bn(x) - x = x.flatten(2).transpose(1, 2) - return x - -################################################################################### -def cube_root(n): - return round(math.pow(n, (1 / 3))) - - -################################################################################### -# ----------------------------------------------------- decoder ------------------- -class MLP_(nn.Module): - """ - Linear Embedding - """ - - def __init__(self, input_dim=2048, embed_dim=768): - super().__init__() - self.proj = nn.Linear(input_dim, embed_dim) - self.bn = nn.LayerNorm(embed_dim) - - def forward(self, x): - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.proj(x) - # added batchnorm (remove it ?) 
- x = self.bn(x) - return x - - -################################################################################### -class SegFormerDecoderHead(nn.Module): - """ - SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers - """ - - def __init__( - self, - input_feature_dims: list = [512, 320, 128, 64], - decoder_head_embedding_dim: int = 256, - num_classes: int = 3, - dropout: float = 0.0, - ): - """ - input_feature_dims: list of the output features channels generated by the transformer encoder - decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module - num_classes: number of the output channels - dropout: dropout rate of the concatenated feature maps - """ - super().__init__() - self.linear_c4 = MLP_( - input_dim=input_feature_dims[0], - embed_dim=decoder_head_embedding_dim, - ) - self.linear_c3 = MLP_( - input_dim=input_feature_dims[1], - embed_dim=decoder_head_embedding_dim, - ) - self.linear_c2 = MLP_( - input_dim=input_feature_dims[2], - embed_dim=decoder_head_embedding_dim, - ) - self.linear_c1 = MLP_( - input_dim=input_feature_dims[3], - embed_dim=decoder_head_embedding_dim, - ) - # convolution module to combine feature maps generated by the mlps - self.linear_fuse = nn.Sequential( - nn.Conv3d( - in_channels=4 * decoder_head_embedding_dim, - out_channels=decoder_head_embedding_dim, - kernel_size=1, - stride=1, - bias=False, - ), - nn.BatchNorm3d(decoder_head_embedding_dim), - nn.ReLU(), - ) - self.dropout = nn.Dropout(dropout) - - # final linear projection layer - self.linear_pred = nn.Conv3d( - decoder_head_embedding_dim, num_classes, kernel_size=1 - ) - - # segformer decoder generates the final decoded feature map size at 1/4 of the original input volume size - self.upsample_volume = nn.Upsample( - scale_factor=4.0, mode="trilinear", align_corners=False - ) - - def forward(self, c1, c2, c3, c4): - ############## _MLP decoder on C1-C4 ########### - n, _, _, _, _ = c4.shape - - _c4 = ( - self.linear_c4(c4) - .permute(0, 2, 1) - .reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4]) - .contiguous() - ) - _c4 = torch.nn.functional.interpolate( - _c4, - size=c1.size()[2:], - mode="trilinear", - align_corners=False, - ) - - _c3 = ( - self.linear_c3(c3) - .permute(0, 2, 1) - .reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4]) - .contiguous() - ) - _c3 = torch.nn.functional.interpolate( - _c3, - size=c1.size()[2:], - mode="trilinear", - align_corners=False, - ) - - _c2 = ( - self.linear_c2(c2) - .permute(0, 2, 1) - .reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4]) - .contiguous() - ) - _c2 = torch.nn.functional.interpolate( - _c2, - size=c1.size()[2:], - mode="trilinear", - align_corners=False, - ) - - _c1 = ( - self.linear_c1(c1) - .permute(0, 2, 1) - .reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4]) - .contiguous() - ) - - _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) - - x = self.dropout(_c) - x = self.linear_pred(x) - x = self.upsample_volume(x) - return x - -################################################################################### -if __name__ == "__main__": - input = torch.randint( - low=0, - high=255, - size=(1, 4, 128, 128, 128), - dtype=torch.float, - ) - input = input.to("cuda:0") - segformer3D = SegFormer3D().to("cuda:0") - output = segformer3D(input) - print(output.shape) - - -################################################################################### diff --git a/augmentations/augmentations.py b/augmentations/augmentations.py index 0503de3..8c5d9b6 100644 --- 
a/augmentations/augmentations.py +++ b/augmentations/augmentations.py @@ -1,13 +1,32 @@ import monai.transforms as transforms + ####################################################################################### def build_augmentations(train: bool = True): if train: train_transform = [ - transforms.RandSpatialCropSamplesd(keys=["image", "label"], roi_size=(96, 96, 96), num_samples=4, random_center=True, random_size=False), + transforms.RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=(96, 96, 96), + num_samples=4, + random_center=True, + random_size=False, + ), transforms.RandFlipd(keys=["image", "label"], prob=0.30, spatial_axis=1), - transforms.RandRotated(keys=["image", "label"], prob=0.50, range_x=0.36, range_y=0.0, range_z=0.0), - transforms.RandCoarseDropoutd(keys=["image", "label"], holes=20, spatial_size=(-1, 7, 7), fill_value=0, prob=0.5), + transforms.RandRotated( + keys=["image", "label"], + prob=0.50, + range_x=0.36, + range_y=0.0, + range_z=0.0, + ), + transforms.RandCoarseDropoutd( + keys=["image", "label"], + holes=20, + spatial_size=(-1, 7, 7), + fill_value=0, + prob=0.5, + ), transforms.GibbsNoised(keys=["image"]), transforms.EnsureTyped(keys=["image", "label"], track_meta=False), ] diff --git a/data/brats2017_seg/brats2017_raw_data/brats2017_seg_preprocess.py b/data/brats2017_seg/brats2017_raw_data/brats2017_seg_preprocess.py index 9d34e47..3868847 100644 --- a/data/brats2017_seg/brats2017_raw_data/brats2017_seg_preprocess.py +++ b/data/brats2017_seg/brats2017_raw_data/brats2017_seg_preprocess.py @@ -8,7 +8,7 @@ from matplotlib import animation from monai.data import MetaTensor from multiprocessing import Process, Pool -from sklearn.preprocessing import MinMaxScaler +from sklearn.preprocessing import MinMaxScaler from monai.transforms import ( Orientation, EnsureType, @@ -40,25 +40,37 @@ │ │ └──... """ + + class ConvertToMultiChannelBasedOnBrats2017Classes(object): """ Convert labels to multi channels based on brats17 classes: - "0": "background", + "0": "background", "1": "edema", "2": "non-enhancing tumor", "3": "enhancing tumour" Annotations comprise the GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2), and the necrotic and non-enhancing tumor (NCR/NET — label 1) """ + def __call__(self, img): # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = img.squeeze(0) - result = [(img == 2) | (img == 3), (img == 2) | (img == 3) | (img == 1), img == 3] + result = [ + (img == 2) | (img == 3), + (img == 2) | (img == 3) | (img == 1), + img == 3, + ] # merge labels 1 (tumor non-enh) and 3 (tumor enh) and 1 (large edema) to WT # label 3 is ET - return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) + return ( + torch.stack(result, dim=0) + if isinstance(img, torch.Tensor) + else np.stack(result, axis=0) + ) + class Brats2017Task1Preprocess: def __init__( @@ -78,19 +90,24 @@ def __init__( label_folder_dir = os.path.join(root_dir, train_folder_name, "labelsTr") assert os.path.exists(self.train_folder_dir) assert os.path.exists(label_folder_dir) - + self.save_dir = save_dir - # we only care about case names for which we have label! + # we only care about case names for which we have label! 
self.case_name = next(os.walk(label_folder_dir), (None, None, []))[2] - - # MRI type - self.MRI_CODE = {"Flair": "0000", "T1w": "0001", "T1gd": "0002", "T2w": "0003", "label": None} + # MRI type + self.MRI_CODE = { + "Flair": "0000", + "T1w": "0001", + "T1gd": "0002", + "T2w": "0003", + "label": None, + } def __len__(self): return self.case_name.__len__() - def normalize(self, x:np.ndarray)->np.ndarray: + def normalize(self, x: np.ndarray) -> np.ndarray: # Transform features by scaling each feature to a given range. scaler = MinMaxScaler(feature_range=(0, 1)) # (H, W, D) -> (H * W, D) @@ -107,12 +124,12 @@ def detach_meta(self, x: MetaTensor) -> np.ndarray: assert type(x) == MetaTensor return EnsureType(data_type="numpy", track_meta=False)(x) - def crop_brats2021_zero_pixels(self, x: np.ndarray)->np.ndarray: + def crop_brats2021_zero_pixels(self, x: np.ndarray) -> np.ndarray: # get rid of the zero pixels around mri scan and cut it so that the region is useful # crop (240, 240, 155) to (128, 128, 128) return x[:, 56:184, 56:184, 13:141] - def remove_case_name_artifact(self, case_name: str)->str: + def remove_case_name_artifact(self, case_name: str) -> str: # BRATS_066.nii.gz -> BRATS_066 return case_name.rsplit(".")[0] @@ -161,24 +178,26 @@ def _2metaTensor(self, nifti_data: np.ndarray, affine_mat: np.ndarray): scan = scan.view(1, D, H, W) return scan - def preprocess_brats_modality(self, data_fp: str, is_label: bool = False)->np.ndarray: + def preprocess_brats_modality( + self, data_fp: str, is_label: bool = False + ) -> np.ndarray: """ apply preprocess stage to the modality data_fp: directory to the modality """ data, affine = self.load_nifti(data_fp) - # label do not the be normalized + # label do not the be normalized if is_label: # Binary mask does not need to be float64! For saving storage purposes! data = data.astype(np.uint8) - # categorical -> one-hot-encoded + # categorical -> one-hot-encoded # (240, 240, 155) -> (3, 240, 240, 155) data = ConvertToMultiChannelBasedOnBrats2017Classes()(data) else: data = self.normalize(x=data) # (240, 240, 155) -> (1, 240, 240, 155) data = data[np.newaxis, ...] 
- + data = MetaTensor(x=data, affine=affine) # for oreinting the coordinate system we need the affine matrix data = self.orient(data) @@ -194,39 +213,35 @@ def __getitem__(self, idx): # BRATS_001_0000 case_name = self.remove_case_name_artifact(case_name) - # preprocess Flair modality code = self.MRI_CODE["Flair"] flair = self.get_modality_fp(case_name, "imagesTr", code) Flair = self.preprocess_brats_modality(flair, is_label=False) - flair_transv = Flair.swapaxes(1, 3) # transverse plane + flair_transv = Flair.swapaxes(1, 3) # transverse plane - # preprocess T1w modality code = self.MRI_CODE["T1w"] t1w = self.get_modality_fp(case_name, "imagesTr", code) t1w = self.preprocess_brats_modality(t1w, is_label=False) - t1w_transv = t1w.swapaxes(1, 3) # transverse plane - + t1w_transv = t1w.swapaxes(1, 3) # transverse plane + # preprocess T1gd modality code = self.MRI_CODE["T1gd"] t1gd = self.get_modality_fp(case_name, "imagesTr", code) t1gd = self.preprocess_brats_modality(t1gd, is_label=False) - t1gd_transv = t1gd.swapaxes(1, 3) # transverse plane + t1gd_transv = t1gd.swapaxes(1, 3) # transverse plane - # preprocess T2w code = self.MRI_CODE["T2w"] t2w = self.get_modality_fp(case_name, "imagesTr", code) t2w = self.preprocess_brats_modality(t2w, is_label=False) - t2w_transv = t2w.swapaxes(1, 3) # transverse plane - + t2w_transv = t2w.swapaxes(1, 3) # transverse plane # preprocess segmentation label code = self.MRI_CODE["label"] label = self.get_modality_fp(case_name, "labelsTr", code) label = self.preprocess_brats_modality(label, is_label=True) - label = label.swapaxes(1, 3) # transverse plane + label = label.swapaxes(1, 3) # transverse plane # stack modalities (4, D, H, W) modalities = np.concatenate( @@ -234,9 +249,8 @@ def __getitem__(self, idx): axis=0, dtype=np.float32, ) - - return modalities, label, case_name + return modalities, label, case_name def __call__(self): print("started preprocessing Brats2017...") @@ -260,7 +274,6 @@ def process(self, idx): torch.save(label, label_fn) - def animate(input_1, input_2): """animate pairs of image sequences of the same length on two conjugate axis""" assert len(input_1) == len( @@ -288,7 +301,8 @@ def animate(input_1, input_2): repeat_delay=100, ) -def viz(volume_indx: int = 1, label_indx: int = 1)->None: + +def viz(volume_indx: int = 1, label_indx: int = 1) -> None: """ pair visualization of the volume and label volume_indx: index for the volume. ["Flair", "t1", "t1ce", "t2"] @@ -303,15 +317,12 @@ def viz(volume_indx: int = 1, label_indx: int = 1)->None: if __name__ == "__main__": - brats2017_task1_prep = Brats2017Task1Preprocess(root_dir="./", - train_folder_name = "train", - save_dir="../BraTS2017_Training_Data" + brats2017_task1_prep = Brats2017Task1Preprocess( + root_dir="./", train_folder_name="train", save_dir="../BraTS2017_Training_Data" ) - # run the preprocessing pipeline + # run the preprocessing pipeline brats2017_task1_prep() - # in case you want to visualize the data you can uncomment the following. Change the index to see different data + # in case you want to visualize the data you can uncomment the following. 
Change the index to see different data # volume, label, case_name = brats2017_task1_prep[400] # viz(volume_indx = 3, label_indx = 1) - - diff --git a/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_kfold_csv.py b/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_kfold_csv.py index 404fca5..e03e9e0 100644 --- a/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_kfold_csv.py +++ b/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_kfold_csv.py @@ -27,7 +27,7 @@ def save_pandas_df(dataframe: pd.DataFrame, save_path: str, header: list) -> Non """ assert save_path.endswith("csv") assert isinstance(dataframe, pd.DataFrame) - assert (dataframe.columns.__len__() == header.__len__()) + assert dataframe.columns.__len__() == header.__len__() dataframe.to_csv(path_or_buf=save_path, header=header, index=False) diff --git a/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_test_csv.py b/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_test_csv.py index aeb746a..fbec7cc 100644 --- a/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_test_csv.py +++ b/data/brats2017_seg/brats2017_raw_data/datameta_generator/create_train_val_test_csv.py @@ -57,7 +57,7 @@ def create_train_val_test_csv_from_data_folder( train_sample_base_dir = np.array(data_dir)[train_idx] train_sample_case_name = np.array(case_name)[train_idx] - # we do not need test split so we can merge it with validation + # we do not need test split so we can merge it with validation val_idx = np.concatenate((val_idx, test_idx), axis=0) validation_sample_base_dir = np.array(data_dir)[val_idx] validation_sample_case_name = np.array(case_name)[val_idx] @@ -91,7 +91,6 @@ def create_train_val_test_csv_from_data_folder( ) - if __name__ == "__main__": create_train_val_test_csv_from_data_folder( # path to the train data folder diff --git a/data/brats2017_seg/brats2017_raw_data/datameta_generator/nnformer_train_test_split.py b/data/brats2017_seg/brats2017_raw_data/datameta_generator/nnformer_train_test_split.py index dfef98a..df30da5 100644 --- a/data/brats2017_seg/brats2017_raw_data/datameta_generator/nnformer_train_test_split.py +++ b/data/brats2017_seg/brats2017_raw_data/datameta_generator/nnformer_train_test_split.py @@ -470,31 +470,30 @@ "BRATS_397", ] test_split = [ - "BRATS_058", - "BRATS_059", - "BRATS_076", - "BRATS_077", - "BRATS_099", - "BRATS_113", - "BRATS_114", - "BRATS_124", - "BRATS_139", - "BRATS_151", - "BRATS_152", - "BRATS_157", - "BRATS_190", - "BRATS_240", - "BRATS_242", - "BRATS_295", - "BRATS_305", - "BRATS_325", - "BRATS_331", - "BRATS_362", - "BRATS_389", - "BRATS_425", - "BRATS_432", - "BRATS_450" - + "BRATS_058", + "BRATS_059", + "BRATS_076", + "BRATS_077", + "BRATS_099", + "BRATS_113", + "BRATS_114", + "BRATS_124", + "BRATS_139", + "BRATS_151", + "BRATS_152", + "BRATS_157", + "BRATS_190", + "BRATS_240", + "BRATS_242", + "BRATS_295", + "BRATS_305", + "BRATS_325", + "BRATS_331", + "BRATS_362", + "BRATS_389", + "BRATS_425", + "BRATS_432", + "BRATS_450", ] train = [] @@ -509,14 +508,14 @@ data={"base_dir": train, "case_name": train_split}, index=None, columns=None, - ) +) # create a pandas data frame validation_df = pd.DataFrame( data={"base_dir": test, "case_name": test_split}, index=None, columns=None, - ) +) # write csv files to the drive! 
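The split-generation scripts above and the dataset classes later in this diff communicate through two-column CSV files. Below is a minimal sketch of how one generated row is expected to resolve to the preprocessed tensors; the case ID and directory are illustrative, and the column names follow the [data_path, case_name] convention documented in the Brats2021 dataset class further down.

import os
import pandas as pd

# one illustrative row of a generated train.csv / validation.csv
row = {
    "data_path": "BraTS2021_Training_Data/BraTS2021_00000",  # illustrative directory
    "case_name": "BraTS2021_00000",                           # illustrative case ID
}
pd.DataFrame([row]).to_csv("train.csv", index=False)

# the dataset __getitem__ resolves each row to the two preprocessed tensors
volume_fp = os.path.join(row["data_path"], f"{row['case_name']}_modalities.pt")
label_fp = os.path.join(row["data_path"], f"{row['case_name']}_label.pt")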
diff --git a/data/brats2021_seg/brats2021_raw_data/brats2021_seg_preprocess.py b/data/brats2021_seg/brats2021_raw_data/brats2021_seg_preprocess.py index a90e7f2..ffb0b69 100644 --- a/data/brats2021_seg/brats2021_raw_data/brats2021_seg_preprocess.py +++ b/data/brats2021_seg/brats2021_raw_data/brats2021_seg_preprocess.py @@ -7,9 +7,9 @@ from matplotlib import animation from monai.data import MetaTensor from multiprocessing import Process, Pool -from sklearn.preprocessing import MinMaxScaler +from sklearn.preprocessing import MinMaxScaler from monai.transforms import ( - Orientation, + Orientation, EnsureType, ConvertToMultiChannelBasedOnBratsClasses, ) @@ -51,11 +51,11 @@ def __init__( # MRI type self.MRI_TYPE = ["flair", "t1", "t1ce", "t2", "seg"] self.save_dir = save_dir - + def __len__(self): return self.case_name.__len__() - def get_modality_fp(self, case_name: str, mri_type: str)->str: + def get_modality_fp(self, case_name: str, mri_type: str) -> str: """ return the modality file path case_name: patient ID @@ -68,7 +68,7 @@ def get_modality_fp(self, case_name: str, mri_type: str)->str: ) return modality_fp - def load_nifti(self, fp)->list: + def load_nifti(self, fp) -> list: """ load a nifti file fp: path to the nifti file with (nii or nii.gz) extension @@ -80,7 +80,7 @@ def load_nifti(self, fp)->list: affine = nifti_data.affine return nifti_scan, affine - def normalize(self, x:np.ndarray)->np.ndarray: + def normalize(self, x: np.ndarray) -> np.ndarray: # Transform features by scaling each feature to a given range. scaler = MinMaxScaler(feature_range=(0, 1)) # (H, W, D) -> (H * W, D) @@ -97,29 +97,31 @@ def detach_meta(self, x: MetaTensor) -> np.ndarray: assert type(x) == MetaTensor return EnsureType(data_type="numpy", track_meta=False)(x) - def crop_brats2021_zero_pixels(self, x: np.ndarray)->np.ndarray: + def crop_brats2021_zero_pixels(self, x: np.ndarray) -> np.ndarray: # get rid of the zero pixels around mri scan and cut it so that the region is useful # crop (1, 240, 240, 155) to (1, 128, 128, 128) return x[:, 56:184, 56:184, 13:141] - def preprocess_brats_modality(self, data_fp: str, is_label: bool = False)->np.ndarray: + def preprocess_brats_modality( + self, data_fp: str, is_label: bool = False + ) -> np.ndarray: """ apply preprocess stage to the modality data_fp: directory to the modality """ data, affine = self.load_nifti(data_fp) - # label do not the be normalized + # label do not the be normalized if is_label: # Binary mask does not need to be float64! For saving storage purposes! data = data.astype(np.uint8) - # categorical -> one-hot-encoded + # categorical -> one-hot-encoded # (240, 240, 155) -> (3, 240, 240, 155) data = ConvertToMultiChannelBasedOnBratsClasses()(data) else: data = self.normalize(x=data) # (240, 240, 155) -> (1, 240, 240, 155) data = data[np.newaxis, ...] 
- + data = MetaTensor(x=data, affine=affine) # for oreinting the coordinate system we need the affine matrix data = self.orient(data) @@ -132,33 +134,33 @@ def preprocess_brats_modality(self, data_fp: str, is_label: bool = False)->np.nd def __getitem__(self, idx): case_name = self.case_name[idx] # e.g: train/BraTS2021_00000/BraTS2021_00000_flair.nii.gz - + # preprocess Flair modality FLAIR = self.get_modality_fp(case_name, self.MRI_TYPE[0]) flair = self.preprocess_brats_modality(data_fp=FLAIR, is_label=False) - flair_transv = flair.swapaxes(1, 3) # transverse plane - + flair_transv = flair.swapaxes(1, 3) # transverse plane + # # preprocess T1 modality # T1 = self.get_modality_fp(case_name, self.MRI_TYPE[1]) # t1 = self.preprocess_brats_modality(data_fp=T1, is_label=False) # t1_transv = t1.swapaxes(1, 3) # transverse plane - + # preprocess T1ce modality T1ce = self.get_modality_fp(case_name, self.MRI_TYPE[2]) t1ce = self.preprocess_brats_modality(data_fp=T1ce, is_label=False) - t1ce_transv = t1ce.swapaxes(1, 3) # transverse plane - + t1ce_transv = t1ce.swapaxes(1, 3) # transverse plane + # preprocess T2 T2 = self.get_modality_fp(case_name, self.MRI_TYPE[3]) t2 = self.preprocess_brats_modality(data_fp=T2, is_label=False) - t2_transv = t2.swapaxes(1, 3) # transverse plane - + t2_transv = t2.swapaxes(1, 3) # transverse plane + # preprocess segmentation label Label = self.get_modality_fp(case_name, self.MRI_TYPE[4]) label = self.preprocess_brats_modality(data_fp=Label, is_label=True) - label_transv = label.swapaxes(1, 3) # transverse plane + label_transv = label.swapaxes(1, 3) # transverse plane - # stack modalities along the first dimension + # stack modalities along the first dimension modalities = np.concatenate( (flair_transv, t1ce_transv, t2_transv), axis=0, @@ -174,7 +176,6 @@ def __call__(self): multi_p.join() print("finished preprocessing brats2021...") - def process(self, idx): if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) @@ -192,7 +193,6 @@ def process(self, idx): torch.save(label, label_fn) - def animate(input_1, input_2): """animate pairs of image sequences of the same length on two conjugate axis""" assert len(input_1) == len( @@ -220,7 +220,8 @@ def animate(input_1, input_2): repeat_delay=100, ) -def viz(volume_indx: int = 1, label_indx: int = 1)->None: + +def viz(volume_indx: int = 1, label_indx: int = 1) -> None: """ pair visualization of the volume and label volume_indx: index for the volume. 
["flair", "t1", "t1ce", "t2"] @@ -236,13 +237,11 @@ def viz(volume_indx: int = 1, label_indx: int = 1)->None: if __name__ == "__main__": brats2021_task1_prep = Brats2021Task1Preprocess( - root_dir="./", - save_dir="../BraTS2021_Training_Data" - ) - # start preprocessing + root_dir="./", save_dir="../BraTS2021_Training_Data" + ) + # start preprocessing brats2021_task1_prep() - # visualization + # visualization # volume, label, _ = brats2021_task1_prep[100] # viz(volume_indx = 0, label_indx = 2) - diff --git a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_kfold_csv.py b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_kfold_csv.py index 77e2873..b935940 100644 --- a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_kfold_csv.py +++ b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_kfold_csv.py @@ -27,7 +27,7 @@ def save_pandas_df(dataframe: pd.DataFrame, save_path: str, header: list) -> Non """ assert save_path.endswith("csv") assert isinstance(dataframe, pd.DataFrame) - assert (dataframe.columns.__len__() == header.__len__()) + assert dataframe.columns.__len__() == header.__len__() dataframe.to_csv(path_or_buf=save_path, header=header, index=False) diff --git a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv.py b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv.py index 60e8014..3ca352a 100644 --- a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv.py +++ b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd + def create_pandas_df(data_dict: dict) -> pd.DataFrame: """ create a pandas dataframe out of data dictionary @@ -25,7 +26,7 @@ def save_pandas_df(dataframe: pd.DataFrame, save_path: str, header: list) -> Non """ assert save_path.endswith("csv") assert isinstance(dataframe, pd.DataFrame) - assert (dataframe.columns.__len__() == header.__len__()) + assert dataframe.columns.__len__() == header.__len__() dataframe.to_csv(path_or_buf=save_path, header=header, index=False) @@ -82,12 +83,11 @@ def create_train_val_test_csv_from_data_folder( train_sample_base_dir = np.array(data_dir)[train_idx] train_sample_case_name = np.array(case_name)[train_idx] - # we do not need test split so we can merge it with validation + # we do not need test split so we can merge it with validation val_idx = np.concatenate((val_idx, test_idx), axis=0) validation_sample_base_dir = np.array(data_dir)[val_idx] validation_sample_case_name = np.array(case_name)[val_idx] - # dictionary object to get converte to dataframe train_data = {"data_path": train_dp, "label_path": train_fold_cn} valid_data = {"data_path": valid_dp, "label_path": valid_fold_cn} @@ -106,21 +106,6 @@ def create_train_val_test_csv_from_data_folder( header=header, ) - - - - - - - - - - - - - - - # create a pandas data frame train_df = pd.DataFrame( data={"base_dir": train_sample_base_dir, "case_name": train_sample_case_name}, diff --git a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv_v2.py b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv_v2.py index d0bcf7b..2767ee8 100644 --- a/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv_v2.py +++ b/data/brats2021_seg/brats2021_raw_data/datameta_generator/create_train_val_test_csv_v2.py @@ -26,7 +26,7 @@ def 
save_pandas_df(dataframe: pd.DataFrame, save_path: str, header: list) -> Non """ assert save_path.endswith("csv") assert isinstance(dataframe, pd.DataFrame) - assert (dataframe.columns.__len__() == header.__len__()) + assert dataframe.columns.__len__() == header.__len__() dataframe.to_csv(path_or_buf=save_path, header=header, index=False) @@ -60,7 +60,7 @@ def create_train_val_test_csv_from_data_folder( # iterate through the folder to list all the filenames case_name = next(os.walk(folder_dir), (None, None, []))[1] - # appending append_dir to the case name based on the anatamical planes + # appending append_dir to the case name based on the anatamical planes planes = ["sagittal", "coronal", "transverse"] data_fp = [] label_fp = [] @@ -71,9 +71,11 @@ def create_train_val_test_csv_from_data_folder( # BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_sagittal_modalities.pt data_fp.append(os.path.join(append_dir, case, case_data).replace("\\", "/")) # BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_sagittal_label.pt - label_fp.append(os.path.join(append_dir, case, case_label).replace("\\", "/")) + label_fp.append( + os.path.join(append_dir, case, case_label).replace("\\", "/") + ) - # we have three anatomical plane for each cse + # we have three anatomical plane for each cse cropus_sample_count = case_name.__len__() * 3 idx = np.arange(0, cropus_sample_count) @@ -93,15 +95,20 @@ def create_train_val_test_csv_from_data_folder( train_sample_data_fp = np.array(data_fp)[train_idx] train_sample_label_fp = np.array(label_fp)[train_idx] - # we do not need test split so we can merge it with validation + # we do not need test split so we can merge it with validation val_idx = np.concatenate((val_idx, test_idx), axis=0) validation_sample_data_fp = np.array(data_fp)[val_idx] validation_sample_label_fp = np.array(label_fp)[val_idx] - # dictionary object to get converte to dataframe - train_data = {"data_path": train_sample_data_fp, "label_path": train_sample_label_fp} - valid_data = {"data_path": validation_sample_data_fp, "label_path": validation_sample_label_fp} + train_data = { + "data_path": train_sample_data_fp, + "label_path": train_sample_label_fp, + } + valid_data = { + "data_path": validation_sample_data_fp, + "label_path": validation_sample_label_fp, + } train_df = create_pandas_df(train_data) valid_df = create_pandas_df(valid_data) diff --git a/dataloaders/brats2017_seg.py b/dataloaders/brats2017_seg.py index a718057..3edb4a2 100644 --- a/dataloaders/brats2017_seg.py +++ b/dataloaders/brats2017_seg.py @@ -51,7 +51,10 @@ def __getitem__(self, idx): # load the preprocessed tensors volume = torch.load(volume_fp) label = torch.load(label_fp) - data = {"image": torch.from_numpy(volume).float(), "label": torch.from_numpy(label).float()} + data = { + "image": torch.from_numpy(volume).float(), + "label": torch.from_numpy(label).float(), + } if self.transform: data = self.transform(data) diff --git a/dataloaders/brats2021_seg.py b/dataloaders/brats2021_seg.py index ef3a256..16be619 100644 --- a/dataloaders/brats2021_seg.py +++ b/dataloaders/brats2021_seg.py @@ -1,51 +1,62 @@ -import os +import os import torch import pandas as pd -from torch.utils.data import Dataset +from torch.utils.data import Dataset + class Brats2021Task1Dataset(Dataset): - """ - Brats2021 task 1 dataset is the segmentation corpus of the data. This dataset class performs dataloading - on an already-preprocessed brats2021 data which has been resized, normalized and oriented in (Right, Anterior, Superior) format. 
- The csv file associated with the data has two columns: [data_path, case_name] + """ + Brats2021 task 1 dataset is the segmentation corpus of the data. This dataset class performs dataloading + on already-preprocessed brats2021 data which has been resized, normalized and oriented in (Right, Anterior, Superior) format. + The csv file associated with the data has two columns: [data_path, case_name] MRI_TYPE are "FLAIR", "T1", "T1CE", "T2" and segmentation label is store separately - """ - def __init__(self, root_dir: str, is_train: bool = True, transform = None, fold_id: int = None): - """ - root_dir: path to (BraTS2021_Training_Data) folder - is_train: whether or nor it is train or validation - transform: composition of the pytorch transforms - fold_id: fold index in kfold dataheld out - """ - super().__init__() - if fold_id is not None: - csv_name = f"train_fold_{fold_id}.csv" if is_train else f"validation_fold_{fold_id}.csv" - csv_fp = os.path.join(root_dir, csv_name) - assert os.path.exists(csv_fp) - else: - csv_name = "train.csv" if is_train else "validation.csv" - csv_fp = os.path.join(root_dir, csv_name) - assert os.path.exists(csv_fp) - - self.csv = pd.read_csv(csv_fp) - self.transform = transform - - def __len__(self): - return self.csv.__len__() - - def __getitem__(self, idx): - data_path = self.csv["data_path"][idx] - case_name = self.csv["case_name"][idx] - # e.g, BraTS2021_00000_trnasverse_modalities.pt - # e.g, BraTS2021_00000_trnasverse_label.pt - volume_fp = os.path.join(data_path, f"{case_name}_modalities.pt") - label_fp = os.path.join(data_path, f"{case_name}_label.pt") - # load the preprocessed tensors - volume = torch.load(volume_fp) - label = torch.load(label_fp) - data = {"image": torch.from_numpy(volume).float(), "label": torch.from_numpy(label).float()} - - if self.transform: - data = self.transform(data) - - return data \ No newline at end of file + """ + + def __init__( + self, root_dir: str, is_train: bool = True, transform=None, fold_id: int = None + ): + """ + root_dir: path to (BraTS2021_Training_Data) folder + is_train: whether or not it is train or validation + transform: composition of the pytorch transforms + fold_id: fold index in the k-fold held-out data + """ + super().__init__() + if fold_id is not None: + csv_name = ( + f"train_fold_{fold_id}.csv" + if is_train + else f"validation_fold_{fold_id}.csv" + ) + csv_fp = os.path.join(root_dir, csv_name) + assert os.path.exists(csv_fp) + else: + csv_name = "train.csv" if is_train else "validation.csv" + csv_fp = os.path.join(root_dir, csv_name) + assert os.path.exists(csv_fp) + + self.csv = pd.read_csv(csv_fp) + self.transform = transform + + def __len__(self): + return self.csv.__len__() + + def __getitem__(self, idx): + data_path = self.csv["data_path"][idx] + case_name = self.csv["case_name"][idx] + # e.g., BraTS2021_00000_transverse_modalities.pt + # e.g., BraTS2021_00000_transverse_label.pt + volume_fp = os.path.join(data_path, f"{case_name}_modalities.pt") + label_fp = os.path.join(data_path, f"{case_name}_label.pt") + # load the preprocessed tensors + volume = torch.load(volume_fp) + label = torch.load(label_fp) + data = { + "image": torch.from_numpy(volume).float(), + "label": torch.from_numpy(label).float(), + } + + if self.transform: + data = self.transform(data) + + return data diff --git a/losses/losses.py b/losses/losses.py index b242f53..9571516 100644 --- a/losses/losses.py +++ b/losses/losses.py @@ -4,6 +4,7 @@ from typing import Dict from monai import losses + class CrossEntropyLoss(nn.Module):
def __init__(self): super().__init__() @@ -23,6 +24,8 @@ def __init__(self): def forward(self, predictions, tragets): loss = self._loss(predictions, tragets) return loss + + ########################################################################### class DiceLoss(nn.Module): def __init__(self): @@ -58,6 +61,6 @@ def build_loss_fn(loss_type: str, loss_args: Dict = None): elif loss_type == "diceCE": return DiceCELoss() - + else: raise ValueError("must be cross entropy or soft dice loss for now!") diff --git a/metrics/segmentation_metrics.py b/metrics/segmentation_metrics.py index a630223..4d434ce 100644 --- a/metrics/segmentation_metrics.py +++ b/metrics/segmentation_metrics.py @@ -44,7 +44,7 @@ def __call__( # compute accuracy per channel acc = self.dice_metric.aggregate().cpu().numpy() avg_acc = acc.mean() - # To access individual metric + # To access individual metric # TC acc: acc[0] # WT acc: acc[1] # ET acc: acc[2]
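
For context on how the dataset, loss, and metric pieces touched in this diff fit together, here is a minimal validation-pass sketch. It is illustrative only and not part of the diff: it assumes the preprocessed BraTS2021_Training_Data folder and the k-fold CSVs exist, borrows the "diceCE" keyword shown in losses.py above, and substitutes a hypothetical 1x1x1 Conv3d for the real SegFormer3D model; the sigmoid-threshold step is likewise an assumption about how binary masks are produced.

# Illustrative usage sketch (not part of the diff).
import torch
from torch.utils.data import DataLoader
from monai.metrics import DiceMetric

from dataloaders.brats2021_seg import Brats2021Task1Dataset
from losses.losses import build_loss_fn

# validation split of fold 0; root_dir is assumed to contain the fold CSVs
val_dataset = Brats2021Task1Dataset(
    root_dir="./BraTS2021_Training_Data",
    is_train=False,
    transform=None,
    fold_id=0,
)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

loss_fn = build_loss_fn(loss_type="diceCE")  # keyword taken from losses.py above
dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

model = torch.nn.Conv3d(4, 3, kernel_size=1)  # hypothetical stand-in for SegFormer3D
model.eval()

with torch.no_grad():
    for batch in val_loader:
        image, label = batch["image"], batch["label"]
        logits = model(image)                 # (B, 3, D, H, W) logits
        loss = loss_fn(logits, label)
        preds = (torch.sigmoid(logits) > 0.5).float()  # assumed binarization step
        dice_metric(y_pred=preds, y=label)    # accumulate per-channel Dice

per_channel = dice_metric.aggregate().cpu().numpy()
# per-channel indexing mirrors segmentation_metrics.py:
# TC: per_channel[0], WT: per_channel[1], ET: per_channel[2]
print("mean Dice:", per_channel.mean())
dice_metric.reset()

The per-channel aggregation with reduction="mean_batch" is what allows the TC/WT/ET indexing noted in the segmentation_metrics.py hunk above.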