Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ Unet++
.. autoclass:: segmentation_models_pytorch.UnetPlusPlus


.. _attentionunet:

AttentionUnet
~~~~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.AttentionUnet


.. _fpn:

FPN
Expand Down
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .decoders.unet import Unet
from .decoders.unetplusplus import UnetPlusPlus
from .decoders.attentionunet import AttentionUnet
from .decoders.manet import MAnet
from .decoders.linknet import Linknet
from .decoders.fpn import FPN
Expand All @@ -26,6 +27,7 @@
_MODEL_ARCHITECTURES = [
Unet,
UnetPlusPlus,
AttentionUnet,
MAnet,
Linknet,
FPN,
Expand Down Expand Up @@ -76,6 +78,7 @@ def create_model(
"losses",
"metrics",
"Unet",
"AttentionUnet",
"UnetPlusPlus",
"MAnet",
"Linknet",
Expand Down
70 changes: 67 additions & 3 deletions segmentation_models_pytorch/base/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
from inplace_abn import InPlaceABN
Expand All @@ -16,13 +17,13 @@ def get_norm_layer(

# Step 1. Convert tot dict representation

## Check boolean
# Check boolean
if use_norm is True:
norm_params = {"type": "batchnorm"}
elif use_norm is False:
norm_params = {"type": "identity"}

## Check string
# Check string
elif isinstance(use_norm, str):
norm_str = use_norm.lower()
if norm_str == "inplace":
Expand All @@ -39,7 +40,7 @@ def get_norm_layer(
f"{supported_norms}"
)

## Check dict
# Check dict
elif isinstance(use_norm, dict):
norm_params = use_norm

Expand Down Expand Up @@ -195,3 +196,66 @@ def __init__(self, name, **params):

def forward(self, x):
return self.attention(x)


class AttentionGate(nn.Module):
"""
Attention Gate as in Attention U-Net (Oktay et al., 2018).

Reference:
https://arxiv.org/abs/1804.03999
"""

def __init__(
self,
in_channels,
gating_channels,
inter_channels=None,
upsample_mode="bilinear",
):
super().__init__()
if inter_channels is None:
inter_channels = in_channels

self.upsample_mode = upsample_mode

# Downsample skip connection to match gating signal
self.theta = nn.Conv2d(
in_channels, inter_channels, kernel_size=1, stride=2, padding=0, bias=True
)
self.W_g = nn.Sequential(
nn.Conv2d(
gating_channels,
inter_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
),
nn.BatchNorm2d(inter_channels),
)
self.psi = nn.Sequential(
nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.Sigmoid(),
)
self.relu = nn.ReLU(inplace=True)

def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
# Downsample skip connection
theta_x = self.theta(x)

# Transform gating signal
phi_g = self.W_g(g)

if phi_g.shape[-2:] != theta_x.shape[-2:]:
phi_g = F.interpolate(phi_g, size=theta_x.shape[-2:], mode="bilinear")

# Compute attention
f = self.relu(theta_x + phi_g)
alpha = self.psi(f)

# Upsample attention to original skip connection size
alpha = F.interpolate(alpha, size=x.shape[2:], mode=self.upsample_mode)

# Apply attention to skip connection
return x * alpha
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .model import AttentionUnet

__all__ = ["AttentionUnet"]
183 changes: 183 additions & 0 deletions segmentation_models_pytorch/decoders/attentionunet/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class AttentionUnetDecoderBlock(nn.Module):
"""Decoder block for Attention U-Net."""

def __init__(
self,
in_channels: int,
skip_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
attention_type: Optional[str] = None,
interpolation_mode: str = "nearest",
):
super().__init__()
self.interpolation_mode = interpolation_mode

self.attention_gate = None
if skip_channels > 0:
self.attention_gate = md.AttentionGate(
in_channels=skip_channels,
gating_channels=in_channels,
inter_channels=skip_channels,
upsample_mode="bilinear",
)

self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_norm=use_norm,
)
self.attention1 = md.Attention(
attention_type, in_channels=in_channels + skip_channels
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_norm=use_norm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)

def forward(
self,
feature_map: torch.Tensor,
target_height: int,
target_width: int,
skip_connection: Optional[torch.Tensor] = None,
) -> torch.Tensor:
feature_map = F.interpolate(
feature_map,
size=(target_height, target_width),
mode=self.interpolation_mode,
)
if skip_connection is not None and self.attention_gate is not None:
skip_connection = self.attention_gate(skip_connection, feature_map)
feature_map = torch.cat([feature_map, skip_connection], dim=1)
feature_map = self.attention1(feature_map)
feature_map = self.conv1(feature_map)
feature_map = self.conv2(feature_map)
feature_map = self.attention2(feature_map)
return feature_map


class CenterBlock(nn.Sequential):
"""Center block of the decoder. Applied to the last feature map of the encoder."""

def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_norm=use_norm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_norm=use_norm,
)
super().__init__(conv1, conv2)


class AttentionUnetDecoder(nn.Module):
"""The decoder part of the Attention U-Net architecture.

Takes encoded features from different stages of the encoder and progressively upsamples them while
combining with gated skip connections. This helps preserve fine-grained details in the final segmentation.
"""

def __init__(
self,
encoder_channels: Sequence[int],
decoder_channels: Sequence[int],
n_blocks: int = 5,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
attention_type: Optional[str] = None,
add_center_block: bool = False,
interpolation_mode: str = "nearest",
):
super().__init__()

if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)

# remove first skip with same spatial resolution
encoder_channels = encoder_channels[1:]
# reverse channels to start from head of encoder
encoder_channels = encoder_channels[::-1]

# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

if add_center_block:
self.center = CenterBlock(
head_channels,
head_channels,
use_norm=use_norm,
)
else:
self.center = nn.Identity()

# combine decoder keyword arguments
self.blocks = nn.ModuleList()
for block_in_channels, block_skip_channels, block_out_channels in zip(
in_channels, skip_channels, out_channels
):
block = AttentionUnetDecoderBlock(
block_in_channels,
block_skip_channels,
block_out_channels,
use_norm=use_norm,
attention_type=attention_type,
interpolation_mode=interpolation_mode,
)
self.blocks.append(block)

def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
# spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
spatial_shapes = [feature.shape[2:] for feature in features]
spatial_shapes = spatial_shapes[::-1]

# remove first skip with same spatial resolution
features = features[1:]
# reverse channels to start from head of encoder
features = features[::-1]

head = features[0]
skip_connections = features[1:]

x = self.center(head)

for i, decoder_block in enumerate(self.blocks):
# upsample to the next spatial shape
height, width = spatial_shapes[i + 1]
skip_connection = skip_connections[i] if i < len(skip_connections) else None
x = decoder_block(x, height, width, skip_connection=skip_connection)

return x
Loading