# Source code for biapy.models.seunet

"""
This module implements the Squeeze-and-Excitation U-Net (SE-U-Net) architecture,
a variant of the classic U-Net enhanced with Squeeze-and-Excitation (SE) blocks.

The SE-U-Net is designed for various image analysis tasks, particularly dense
prediction problems like image segmentation, in both 2D and 3D. It integrates
SE blocks into its convolutional layers to enable the network to perform
dynamic channel-wise feature recalibration, thereby improving the quality
of learned representations.

Key components and functionalities include:

Classes:

- ``SE_U_Net``: The main Squeeze-and-Excitation U-Net model, comprising an
  encoder (downsampling path), a decoder (upsampling path), and skip connections,
  all enhanced with SE blocks.

This module leverages several building blocks defined in `biapy.models.blocks`,
such as `DoubleConvBlock`, `UpBlock` and `ConvBlock`, utility functions for
normalization (`get_norm_2d`, `get_norm_3d`), and `ProjectionHead` from
`biapy.models.heads`.

Reference:
`Squeeze and Excitation Networks <https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.html>`_.

Image representation:

.. image:: ../../img/models/unet.png
    :width: 100%
    :align: center

Image created with `PlotNeuralNet <https://github.com/HarisIqbal88/PlotNeuralNet>`_.
"""

import torch
import torch.nn as nn
from typing import Dict, List

from biapy.models.blocks import (
    ConvBlock,
    DoubleConvBlock, 
    UpBlock, 
    ConvBlock, 
    get_norm_2d, 
    get_norm_3d, 
    prepare_activation_layers, 
    init_weights
)
from biapy.models.heads import ProjectionHead


[docs] class SE_U_Net(nn.Module): """ Create 2D/3D U-Net with Squeeze-and-Excitation (SE) blocks. This model extends the classic U-Net architecture by incorporating Squeeze-and-Excitation (SE) modules within its convolutional blocks. This design aims to improve feature learning and propagation by allowing the network to perform dynamic channel-wise feature recalibration, leading to better performance in dense prediction tasks like image segmentation. Reference: `Squeeze and Excitation Networks <https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.html>`_. """ def __init__( self, image_shape=(256, 256, 1), activation="ELU", feature_maps=[32, 64, 128, 256], drop_values=[0.1, 0.1, 0.1, 0.1], normalization="none", k_size=3, upsample_layer="convtranspose", yx_down=[2, 2, 2, 2], z_down=[2, 2, 2, 2], output_channels=[1], separated_decoders=False, output_channel_info=["F"], explicit_activations: bool = False, head_activations: List[str] = ["ce_sigmoid"], upsampling_factor=(), upsampling_position="pre", isotropy=False, larger_io=True, contrast: bool = False, contrast_proj_dim: int = 256, return_one_tensor: bool = False, ): """ Initialize the SE_U_Net model. Sets up the encoder (downsampling path), decoder (upsampling path), bottleneck, and optional super-resolution and multi-head output layers. It dynamically selects 2D or 3D convolutional, pooling, and normalization layers based on `ndim` and `isotropy` settings. Parameters ---------- image_shape : 3D/4D tuple Dimensions of the input image. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``. activation : str, optional Activation layer to be used throughout the model. feature_maps : array of ints, optional Feature maps to use on each level. drop_values : float, optional Dropout value to be fixed. normalization : str, optional Normalization layer (one of ``'bn'``, ``'sync_bn'`` ``'in'``, ``'gn'`` or ``'none'``). k_size : int, optional Kernel size. 
upsample_layer : str, optional Type of layer to use to make upsampling. Two options: "convtranspose" or "upsampling". z_down : List of ints, optional Downsampling used in z dimension. Set it to ``1`` if the dataset is not isotropic. yx_down : List of ints, optional Downsampling used in y and x dimensions. Set it to ``1`` if the dataset is not isotropic. output_channels : list of int, optional Output channels of the network. If one value is provided, the model will have a single output head. If two values are provided, the model will have two output heads (e.g. for multi-task learning with instance segmentation and classification). separated_decoders : bool, optional Whether to use separated decoders for each output head. output_channel_info : list of str, optional Information about the type of output channels. Possible values are: - "X": where X is a letter, e.g. "F" for foreground, "D" for distance, "R" for rays, "C" for cpntours, etc. - "class": classification (e.g. for multi-task learning) explicit_activations : bool, optional If True, uses explicit activation functions in the last layers. head_activations : List[str], optional Activation functions to apply to each output head if `explicit_activations` is True. upsampling_factor : tuple of ints, optional Factor of upsampling for super resolution workflow for each dimension. upsampling_position : str, optional Whether the upsampling is going to be made previously (``pre`` option) to the model or after the model (``post`` option). isotropy : bool or list of bool, optional Whether to use 3d or 2d convolutions at each U-Net level even if input is 3d. larger_io : bool, optional Whether to use extra and larger kernels in the input and output layers. contrast : bool, optional Whether to add contrastive learning head to the model. Default is ``False``. contrast_proj_dim : int, optional Dimension of the projection head for contrastive learning. Default is ``256``. 
return_one_tensor : bool, optional Whether to return a single tensor with all outputs concatenated (if False, returns a dictionary with separate entries). Default is ``False``. Returns ------- model : Torch model Residual U-Net model. Calling this function with its default parameters returns the following network: .. image:: ../../img/models/unet.png :width: 100% :align: center Image created with `PlotNeuralNet <https://github.com/HarisIqbal88/PlotNeuralNet>`_. """ super(SE_U_Net, self).__init__() if len(output_channels) == 0: raise ValueError("'output_channels' needs to has at least one value") if contrast and len(output_channels) > 2: raise ValueError("If 'contrast' is True, 'output_channels' can only have two values at max: one for the main output and one for the class.") print("Selected output channels:") for i, info in enumerate(output_channel_info): print(f" - {i} channel for {info} output") self.depth = len(feature_maps) - 1 self.ndim = 3 if len(image_shape) == 4 else 2 self.z_down = z_down self.yx_down = yx_down self.output_channels = output_channels self.output_channel_info = output_channel_info self.return_class = True if "class" in output_channel_info else False self.contrast = contrast self.explicit_activations = explicit_activations self.return_one_tensor = return_one_tensor if self.explicit_activations: assert len(head_activations) == sum(output_channels), "If 'explicit_activations' is True, 'head_activations' needs to " "have the same number of values as 'output_channels'" self.head_activations, self.class_head_activations = prepare_activation_layers(head_activations, output_channel_info, output_channels) if self.return_class and self.class_head_activations is None: raise ValueError("If 'return_class' is True, 'head_activations' must be provided.") if type(isotropy) == bool: isotropy = [isotropy] * len(feature_maps) if self.ndim == 3: conv = nn.Conv3d convtranspose = nn.ConvTranspose3d pooling = nn.MaxPool3d norm_func = get_norm_3d dropout = 
nn.Dropout3d else: conv = nn.Conv2d convtranspose = nn.ConvTranspose2d pooling = nn.MaxPool2d norm_func = get_norm_2d dropout = nn.Dropout2d # Super-resolution self.pre_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "pre": self.pre_upsampling = convtranspose( image_shape[-1], image_shape[-1], kernel_size=upsampling_factor, stride=upsampling_factor, ) # ENCODER self.down_path = nn.ModuleList() self.mpooling_layers = nn.ModuleList() in_channels = image_shape[-1] # extra (larger) input layer if larger_io: kernel_size = (k_size + 2, k_size + 2) if self.ndim == 2 else (k_size + 2, k_size + 2, k_size + 2) if not isotropy[0] and self.ndim == 3: kernel_size = (1, k_size + 2, k_size + 2) self.conv_in = ConvBlock( conv=conv, in_size=in_channels, out_size=feature_maps[0], k_size=kernel_size, act=activation, norm=normalization, ) in_channels = feature_maps[0] else: self.conv_in = None for i in range(self.depth): kernel_size = (k_size, k_size) if self.ndim == 2 else (k_size, k_size, k_size) if not isotropy[i] and self.ndim == 3: kernel_size = (1, k_size, k_size) self.down_path.append( DoubleConvBlock( conv=conv, in_size=in_channels, out_size=feature_maps[i], k_size=kernel_size, act=activation, norm=normalization, dropout=drop_values[i], se_block=True, ) ) mpool = (z_down[i], yx_down[i], yx_down[i]) if self.ndim == 3 else (yx_down[i], yx_down[i]) self.mpooling_layers.append(pooling(mpool)) in_channels = feature_maps[i] kernel_size = (k_size, k_size) if self.ndim == 2 else (k_size, k_size, k_size) if not isotropy[-1] and self.ndim == 3: kernel_size = (1, k_size, k_size) self.bottleneck = DoubleConvBlock( conv=conv, in_size=in_channels, out_size=feature_maps[-1], k_size=kernel_size, act=activation, norm=normalization, dropout=drop_values[-1], se_block=True, ) # DECODER self.num_decoders = 1 if not separated_decoders else len(output_channels) self.up_paths = nn.ModuleList([nn.ModuleList() for _ in range(self.num_decoders)]) for j in 
range(self.num_decoders): in_channels = feature_maps[-1] for i in range(self.depth - 1, -1, -1): kernel_size = (k_size, k_size) if self.ndim == 2 else (k_size, k_size, k_size) if not isotropy[i] and self.ndim == 3: kernel_size = (1, k_size, k_size) self.up_paths[j].append( UpBlock( ndim=self.ndim, convtranspose=convtranspose, in_size=in_channels, out_size=feature_maps[i], z_down=z_down[i], yx_down=yx_down[i], up_mode=upsample_layer, conv=conv, k_size=kernel_size, act=activation, norm=normalization, dropout=drop_values[i], se_block=True, ) ) # type: ignore in_channels = feature_maps[i] # extra (larger) output layer if larger_io: kernel_size = (k_size + 2, k_size + 2) if self.ndim == 2 else (k_size + 2, k_size + 2, k_size + 2) if not isotropy[0] and self.ndim == 3: kernel_size = (1, k_size + 2, k_size + 2) self.conv_out = nn.ModuleList([ ConvBlock( conv=conv, in_size=feature_maps[0], out_size=feature_maps[0], k_size=kernel_size, act=activation, norm=normalization, ) for _ in range(self.num_decoders) ]) else: self.conv_out = None # Super-resolution self.post_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "post": self.post_upsampling = convtranspose( feature_maps[0], feature_maps[0], kernel_size=upsampling_factor, stride=upsampling_factor, ) if self.contrast: # extra added layers self.heads = nn.Sequential( conv(feature_maps[0], feature_maps[0], kernel_size=3, stride=1, padding=1), norm_func(normalization, feature_maps[0]), dropout(0.10), conv(feature_maps[0], output_channels[0], kernel_size=1, stride=1, padding=0, bias=False), ) self.proj_head = ProjectionHead(ndim=self.ndim, in_channels=feature_maps[0], proj_dim=contrast_proj_dim) else: self.heads = nn.Sequential() for i, out_ch in enumerate(output_channels): self.heads.append(conv(feature_maps[0], out_ch, kernel_size=1, padding="same")) init_weights(self)
[docs] def forward(self, x) -> Dict | torch.Tensor: """ Forward pass of the model. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, channels, height, width) for 2D or (batch_size, channels, depth, height, width) for 3D. Returns ------- Dict or torch.Tensor Model output. Returns a dictionary if multi-head or contrastive outputs are enabled, otherwise returns the main prediction tensor. """ # Super-resolution if self.pre_upsampling: x = self.pre_upsampling(x) # extra large-kernel input layer if self.conv_in: x = self.conv_in(x) # Encoder blocks = [] for i, layers in enumerate(zip(self.down_path, self.mpooling_layers)): down, pool = layers x = down(x) blocks.append(x) x = pool(x) x_bot = self.bottleneck(x) # Decoder feats = [] for j in range(self.num_decoders): x = x_bot for i, up in enumerate(self.up_paths[j]): x = up(x, blocks[-i - 1]) feats.append(x) # extra large-kernel output layer if self.conv_out: for j in range(self.num_decoders): feats[j] = self.conv_out[j](feats[j]) # Super-resolution if self.post_upsampling: feats[0] = self.post_upsampling(feats[0]) out_dict = {} # Pass the features through the output heads class_outs, outs = [], [] for i, head in enumerate(self.heads): feat = feats[i] if self.num_decoders > 1 else feats[0] if "class" in self.output_channel_info[i]: class_outs.append(head(feat)) else: outs.append(head(feat)) outs = torch.cat(outs, dim=1) # Apply activations to the output heads if explicit_activations is True if self.explicit_activations: # If there is only one activation, apply it to the whole tensor if len(self.head_activations) == 1: outs = self.head_activations[0](outs) else: for i, act in enumerate(self.head_activations): outs[:, i:i+1] = act(outs[:, i:i+1]) if self.return_class and self.class_head_activations is not None: for i, act in enumerate(self.class_head_activations): class_outs[i] = act(class_outs[i]) out_dict = { "pred": outs, } if self.return_class: out_dict["class"] = torch.cat(class_outs, dim=1) # 
Contrastive learning head if self.contrast: out_dict["embed"] = self.proj_head(feats[0]) if len(out_dict.keys()) == 1: return out_dict["pred"] else: if self.return_one_tensor: if "class" in out_dict: return torch.cat((out_dict["pred"], torch.argmax(out_dict["class"], dim=1).unsqueeze(1)), dim=1) else: return out_dict["pred"] return out_dict