From e41d2686426aceef3be17fa45d6f85f03069d503 Mon Sep 17 00:00:00 2001
From: Marc Adrian Peters
Date: Sat, 20 Sep 2025 13:04:11 +0200
Subject: [PATCH] Add Attention U-Net

---
 docs/models.rst                             |   7 +
 segmentation_models_pytorch/__init__.py     |   3 +
 segmentation_models_pytorch/base/modules.py |  70 ++++++-
 .../decoders/attentionunet/__init__.py      |   3 +
 .../decoders/attentionunet/decoder.py       | 183 ++++++++++++++++++
 .../decoders/attentionunet/model.py         | 148 ++++++++++++++
 tests/models/base.py                        |   2 +-
 tests/models/test_attentionunet.py          |  26 +++
 8 files changed, 438 insertions(+), 4 deletions(-)
 create mode 100644 segmentation_models_pytorch/decoders/attentionunet/__init__.py
 create mode 100644 segmentation_models_pytorch/decoders/attentionunet/decoder.py
 create mode 100644 segmentation_models_pytorch/decoders/attentionunet/model.py
 create mode 100644 tests/models/test_attentionunet.py

diff --git a/docs/models.rst b/docs/models.rst
index ab04bb5e..ea41583f 100644
--- a/docs/models.rst
+++ b/docs/models.rst
@@ -19,6 +19,13 @@ Unet++
 .. autoclass:: segmentation_models_pytorch.UnetPlusPlus
 
 
+.. _attentionunet:
+
+AttentionUnet
+~~~~~~~~~~~~~
+.. autoclass:: segmentation_models_pytorch.AttentionUnet
+
+
 .. _fpn:
 
 FPN
diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py
index 37c64ef6..39be2d80 100644
--- a/segmentation_models_pytorch/__init__.py
+++ b/segmentation_models_pytorch/__init__.py
@@ -6,6 +6,7 @@
 from .decoders.unet import Unet
 from .decoders.unetplusplus import UnetPlusPlus
+from .decoders.attentionunet import AttentionUnet
 from .decoders.manet import MAnet
 from .decoders.linknet import Linknet
 from .decoders.fpn import FPN
@@ -26,6 +27,7 @@ _MODEL_ARCHITECTURES = [
     Unet,
     UnetPlusPlus,
+    AttentionUnet,
     MAnet,
     Linknet,
     FPN,
@@ -76,6 +78,7 @@ def create_model(
     "losses",
     "metrics",
     "Unet",
+    "AttentionUnet",
     "UnetPlusPlus",
     "MAnet",
     "Linknet",
diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py
index 15cfdb12..c9fbdc7b 100644
--- a/segmentation_models_pytorch/base/modules.py
+++ b/segmentation_models_pytorch/base/modules.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 try:
     from inplace_abn import InPlaceABN
@@ -16,13 +17,13 @@ def get_norm_layer(
 
     # Step 1. Convert tot dict representation
 
-    ## Check boolean
+    # Check boolean
     if use_norm is True:
         norm_params = {"type": "batchnorm"}
     elif use_norm is False:
         norm_params = {"type": "identity"}
 
-    ## Check string
+    # Check string
     elif isinstance(use_norm, str):
         norm_str = use_norm.lower()
         if norm_str == "inplace":
@@ -39,7 +40,7 @@ def get_norm_layer(
                 f"{supported_norms}"
             )
 
-    ## Check dict
+    # Check dict
     elif isinstance(use_norm, dict):
         norm_params = use_norm
 
@@ -195,3 +196,66 @@ def __init__(self, name, **params):
 
     def forward(self, x):
         return self.attention(x)
+
+
+class AttentionGate(nn.Module):
+    """
+    Attention Gate as in Attention U-Net (Oktay et al., 2018).
+
+    Reference:
+        https://arxiv.org/abs/1804.03999
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        gating_channels,
+        inter_channels=None,
+        upsample_mode="bilinear",
+    ):
+        super().__init__()
+        if inter_channels is None:
+            inter_channels = in_channels
+
+        self.upsample_mode = upsample_mode
+
+        # Downsample skip connection to match gating signal
+        self.theta = nn.Conv2d(
+            in_channels, inter_channels, kernel_size=1, stride=2, padding=0, bias=True
+        )
+        self.W_g = nn.Sequential(
+            nn.Conv2d(
+                gating_channels,
+                inter_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=True,
+            ),
+            nn.BatchNorm2d(inter_channels),
+        )
+        self.psi = nn.Sequential(
+            nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.Sigmoid(),
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
+        # Downsample skip connection
+        theta_x = self.theta(x)
+
+        # Transform gating signal
+        phi_g = self.W_g(g)
+
+        if phi_g.shape[-2:] != theta_x.shape[-2:]:
+            phi_g = F.interpolate(phi_g, size=theta_x.shape[-2:], mode="bilinear")
+
+        # Compute attention
+        f = self.relu(theta_x + phi_g)
+        alpha = self.psi(f)
+
+        # Upsample attention to original skip connection size
+        alpha = F.interpolate(alpha, size=x.shape[2:], mode=self.upsample_mode)
+
+        # Apply attention to skip connection
+        return x * alpha
diff --git a/segmentation_models_pytorch/decoders/attentionunet/__init__.py b/segmentation_models_pytorch/decoders/attentionunet/__init__.py
new file mode 100644
index 00000000..f21b867a
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/attentionunet/__init__.py
@@ -0,0 +1,3 @@
+from .model import AttentionUnet
+
+__all__ = ["AttentionUnet"]
diff --git a/segmentation_models_pytorch/decoders/attentionunet/decoder.py b/segmentation_models_pytorch/decoders/attentionunet/decoder.py
new file mode 100644
index 00000000..8b0e2e3b
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/attentionunet/decoder.py
@@ -0,0 +1,183 @@
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from segmentation_models_pytorch.base import modules as md
+
+
+class AttentionUnetDecoderBlock(nn.Module):
+    """Decoder block for Attention U-Net."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
+    ):
+        super().__init__()
+        self.interpolation_mode = interpolation_mode
+
+        self.attention_gate = None
+        if skip_channels > 0:
+            self.attention_gate = md.AttentionGate(
+                in_channels=skip_channels,
+                gating_channels=in_channels,
+                inter_channels=skip_channels,
+                upsample_mode="bilinear",
+            )
+
+        self.conv1 = md.Conv2dReLU(
+            in_channels + skip_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_norm=use_norm,
+        )
+        self.attention1 = md.Attention(
+            attention_type, in_channels=in_channels + skip_channels
+        )
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_norm=use_norm,
+        )
+        self.attention2 = md.Attention(attention_type, in_channels=out_channels)
+
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        feature_map = F.interpolate(
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
+        )
+        if skip_connection is not None and self.attention_gate is not None:
+            skip_connection = self.attention_gate(skip_connection, feature_map)
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+            feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
+
+
+class CenterBlock(nn.Sequential):
+    """Center block of the decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
+        conv1 = md.Conv2dReLU(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_norm=use_norm,
+        )
+        conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_norm=use_norm,
+        )
+        super().__init__(conv1, conv2)
+
+
+class AttentionUnetDecoder(nn.Module):
+    """The decoder part of the Attention U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining them with attention-gated skip connections. This helps preserve fine-grained details in the
+    final segmentation.
+    """
+
+    def __init__(
+        self,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
+    ):
+        super().__init__()
+
+        if n_blocks != len(decoder_channels):
+            raise ValueError(
+                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
+                    n_blocks, len(decoder_channels)
+                )
+            )
+
+        # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[1:]
+        # reverse channels to start from head of encoder
+        encoder_channels = encoder_channels[::-1]
+
+        # computing blocks input and output channels
+        head_channels = encoder_channels[0]
+        in_channels = [head_channels] + list(decoder_channels[:-1])
+        skip_channels = list(encoder_channels[1:]) + [0]
+        out_channels = decoder_channels
+
+        if add_center_block:
+            self.center = CenterBlock(
+                head_channels,
+                head_channels,
+                use_norm=use_norm,
+            )
+        else:
+            self.center = nn.Identity()
+
+        # build decoder blocks
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = AttentionUnetDecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_norm=use_norm,
+                attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
+            )
+            self.blocks.append(block)
+
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
+        # remove first skip with same spatial resolution
+        features = features[1:]
+        # reverse channels to start from head of encoder
+        features = features[::-1]
+
+        head = features[0]
+        skip_connections = features[1:]
+
+        x = self.center(head)
+
+        for i, decoder_block in enumerate(self.blocks):
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)
+
+        return x
diff --git a/segmentation_models_pytorch/decoders/attentionunet/model.py b/segmentation_models_pytorch/decoders/attentionunet/model.py
new file mode 100644
index 00000000..b9ed429f
--- /dev/null
+++ b/segmentation_models_pytorch/decoders/attentionunet/model.py
@@ -0,0 +1,148 @@
+import warnings
+from typing import Any, Dict, Optional, Union, Callable, Sequence
+
+from segmentation_models_pytorch.base import (
+    ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
+)
+from segmentation_models_pytorch.encoders import get_encoder
+from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
+
+from .decoder import AttentionUnetDecoder
+
+
+class AttentionUnet(SegmentationModel):
+    """
+    Attention U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.
+    It extends the original U-Net by incorporating attention gates in the decoder, which help the network focus on
+    relevant spatial regions while suppressing irrelevant background information.
+
+    It consists of two main parts:
+
+    1. An encoder (downsampling path) that extracts increasingly abstract features
+    2. A decoder (upsampling path) with attention gates that selectively emphasize informative features from the encoder
+
+    The key is the use of attention-gated skip connections between corresponding encoder and decoder layers.
+    These connections allow the decoder to access relevant fine-grained details from earlier encoder layers while
+    filtering out irrelevant information, which improves segmentation accuracy, particularly in complex or cluttered scenes.
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
+            Length of the list should be the same as **encoder_depth**
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_name>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
+        decoder_attention_type: Attention module used after convolution in decoder of the model.
+            Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127).
+        decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
+            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**. Default is **None**.
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of encoder if **aux_params** is not **None** (default). Supported params:
+            - classes (int): A number of classes
+            - pooling (str): One of "max", "avg". Default is "avg"
+            - dropout (float): Dropout factor in [0, 1)
+            - activation (str): An activation function to apply "sigmoid"/"softmax"
+              (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
+            Keys with **None** values are pruned before passing.
+
+    Returns:
+        ``torch.nn.Module``: **Attention U-Net**
+
+    Reference:
+        https://arxiv.org/abs/1804.03999
+    """
+
+    requires_divisible_input_shape = False
+
+    @supports_config_loading
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
+        decoder_attention_type: Optional[str] = None,
+        decoder_interpolation: str = "nearest",
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, Callable]] = None,
+        aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
+    ):
+        super().__init__()
+
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please use decoder_use_norm instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+            **kwargs,
+        )
+
+        add_center_block = encoder_name.startswith("vgg")
+
+        decoder_channels = decoder_channels[:encoder_depth]
+        self.decoder = AttentionUnetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            decoder_channels=decoder_channels,
+            n_blocks=encoder_depth,
+            use_norm=decoder_use_norm,
+            add_center_block=add_center_block,
+            attention_type=decoder_attention_type,
+            interpolation_mode=decoder_interpolation,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_channels[-1],
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "attentionunet-{}".format(encoder_name)
+        self.initialize()
diff --git a/tests/models/base.py b/tests/models/base.py
index 2f317348..002584e5 100644
--- a/tests/models/base.py
+++ b/tests/models/base.py
@@ -96,7 +96,7 @@ def test_in_channels_and_depth_and_out_classes(
     ):
         kwargs = {}
 
-        if self.model_type in ["unet", "unetplusplus", "manet"]:
+        if self.model_type in ["unet", "attentionunet", "unetplusplus", "manet"]:
             kwargs = {"decoder_channels": self.decoder_channels[:depth]}
 
         if self.model_type == "dpt":
diff --git a/tests/models/test_attentionunet.py b/tests/models/test_attentionunet.py
new file mode 100644
index 00000000..cafe411d
--- /dev/null
+++ b/tests/models/test_attentionunet.py
@@ -0,0 +1,26 @@
+import segmentation_models_pytorch as smp
+from tests.models import base
+
+
+class TestAttentionUnetModel(base.BaseModelTester):
+    test_model_type = "attentionunet"
+    files_for_diff = [r"decoders/attentionunet/", r"base/"]
+
+    def test_interpolation(self):
+        # test bilinear
+        model_1 = smp.create_model(
+            self.test_model_type,
+            self.test_encoder_name,
+            decoder_interpolation="bilinear",
+        )
+        for block in model_1.decoder.blocks:
+            assert block.interpolation_mode == "bilinear"
+
+        # test bicubic
+        model_2 = smp.create_model(
+            self.test_model_type,
+            self.test_encoder_name,
+            decoder_interpolation="bicubic",
+        )
+        for block in model_2.decoder.blocks:
+            assert block.interpolation_mode == "bicubic"
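The `AttentionGate` added to `base/modules.py` can be sanity-checked in isolation. The sketch below assumes the patch has been applied to a local segmentation-models-pytorch checkout; the tensor shapes are arbitrary, since the gate downsamples the skip tensor internally and resizes the gating signal to match before computing the attention map:

```python
import torch
from segmentation_models_pytorch.base.modules import AttentionGate

gate = AttentionGate(in_channels=64, gating_channels=128, inter_channels=64)

skip = torch.randn(2, 64, 56, 56)     # encoder skip connection
gating = torch.randn(2, 128, 28, 28)  # coarser decoder feature acting as the gating signal

gated_skip = gate(skip, gating)
# The gate only rescales the skip feature map with per-pixel attention weights,
# so the output shape matches the skip connection exactly.
assert gated_skip.shape == skip.shape
```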
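At the model level, a single forward pass is enough to smoke-test the full encoder-decoder path. Again a sketch under the assumption that the patch is installed; the encoder name, input size, and class count are arbitrary choices rather than anything mandated by the patch:

```python
import torch
import segmentation_models_pytorch as smp

# Build the new Attention U-Net; encoder_weights=None skips the ImageNet
# download for a quick local check.
model = smp.AttentionUnet(
    encoder_name="resnet18",
    encoder_weights=None,
    in_channels=3,
    classes=2,
    decoder_interpolation="bilinear",
)

model.eval()
with torch.inference_mode():
    x = torch.randn(1, 3, 224, 224)
    mask = model(x)

print(mask.shape)  # expected: torch.Size([1, 2, 224, 224])
```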
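Finally, because `AttentionUnet` is appended to `_MODEL_ARCHITECTURES`, it is also reachable through the `create_model()` factory under its lowercase name, exactly as the new test file does. The configuration values below (the `instancenorm` dict, `scse` attention) are illustrative choices, not defaults:

```python
import segmentation_models_pytorch as smp

# The dict form of decoder_use_norm follows the documented
# {"type": <norm_name>, **kwargs} structure; extra keys are forwarded
# to the underlying normalization layer.
model = smp.create_model(
    arch="attentionunet",
    encoder_name="resnet34",
    encoder_weights=None,
    classes=3,
    decoder_use_norm={"type": "instancenorm", "affine": True},
    decoder_attention_type="scse",
    decoder_interpolation="nearest",
)

print(model.name)  # attentionunet-resnet34
```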