# Source code for biapy.models.multiresunet

"""
This module implements the MultiResUNet architecture, a U-Net variant designed for multimodal biomedical image segmentation.

MultiResUNet enhances the standard U-Net by incorporating "MultiRes Blocks"
in the encoder and decoder paths, and "ResPaths" for skip connections.
These components aim to improve feature representation and information flow
across different scales.

Key components implemented in this file include:

Classes:

- ``Conv_batchnorm``: A basic convolutional block with optional batch normalization and activation.
- ``Multiresblock``: The MultiRes Block, a core component that processes features
  through parallel convolutional paths of different kernel sizes (3x3, 5x5, 7x7)
  and fuses them.
- ``Respath``: The ResPath module, which acts as an enhanced skip connection,
  applying residual convolutional blocks to features before they are concatenated
  in the decoder.
- ``MultiResUnet``: The main MultiResUNet model, combining the encoder, decoder,
  and skip connections using the MultiRes Blocks and ResPaths.

The implementation supports both 2D and 3D inputs, various normalization types,
and optional multi-head outputs, including a contrastive learning projection.

Reference:
`MultiResUNet : Rethinking the U-Net Architecture for Multimodal Biomedical Image
Segmentation <https://arxiv.org/abs/1902.04049>`_

Code Adapted From:
https://github.com/nibtehaz/MultiResUNet
"""
import torch
import torch.nn as nn
from typing import Dict, List

from biapy.models.heads import ProjectionHead
from biapy.models.blocks import prepare_activation_layers


class Conv_batchnorm(torch.nn.Module):
    """
    Convolution followed by batch normalization and an optional ReLU.

    Parameters
    ----------
    conv : Type[nn.Conv2d | nn.Conv3d]
        Convolutional layer class to instantiate.

    batchnorm : Type[nn.BatchNorm2d | nn.BatchNorm3d]
        Batch normalization layer class to instantiate.

    num_in_filters : int
        Number of input channels of the convolution.

    num_out_filters : int
        Number of output channels of the convolution (and of the batch
        normalization layer).

    kernel_size : int | Tuple[int, ...]
        Size of the convolving kernel.

    stride : int | Tuple[int, ...], optional
        Stride of the convolution. Defaults to 1. With a unit stride the
        convolution uses ``padding="same"``. For any other stride an explicit
        padding of ``kernel_size // 2`` per dimension is used instead, because
        PyTorch does not accept ``padding="same"`` for strided convolutions.

    activation : str, optional
        ``"relu"`` applies a ReLU after batch normalization. Any other value
        (e.g. ``"None"``) returns the batch-normalized tensor unchanged.
        Defaults to ``"relu"``.
    """

    def __init__(
        self,
        conv,
        batchnorm,
        num_in_filters,
        num_out_filters,
        kernel_size,
        stride=1,
        activation="relu",
    ):
        """
        Build the convolution and batch normalization layers.

        See the class docstring for the meaning of each parameter.
        """
        super().__init__()
        self.activation = activation
        self.conv1 = conv(
            in_channels=num_in_filters,
            out_channels=num_out_filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=self._resolve_padding(kernel_size, stride),
        )
        self.batchnorm = batchnorm(num_out_filters)

    @staticmethod
    def _resolve_padding(kernel_size, stride):
        """
        Choose the padding argument for the convolution.

        Returns ``"same"`` when every stride is 1 (the original behaviour).
        Otherwise returns ``kernel_size // 2`` per dimension, which for odd
        kernel sizes yields an output of ``ceil(input / stride)`` elements.
        """
        strides = stride if isinstance(stride, (tuple, list)) else (stride,)
        if all(s == 1 for s in strides):
            return "same"
        if isinstance(kernel_size, (tuple, list)):
            return tuple(k // 2 for k in kernel_size)
        return kernel_size // 2

    def forward(self, x):
        """
        Apply convolution, batch normalization and the optional ReLU.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor accepted by the configured convolution layer.

        Returns
        -------
        torch.Tensor
            Batch-normalized convolution output, passed through ReLU only when
            ``self.activation == "relu"``.
        """
        x = self.conv1(x)
        x = self.batchnorm(x)
        if self.activation == "relu":
            return torch.nn.functional.relu(x)
        return x
class Multiresblock(torch.nn.Module):
    """
    MultiRes block: three chained 3x3 convolutions whose outputs are fused.

    The three ``Conv_batchnorm`` stages are applied one after another, so the
    second and third stages emulate 5x5 and 7x7 receptive fields. Their outputs
    are concatenated along the channel axis, batch-normalized, summed with a
    1x1 shortcut projection of the input, batch-normalized again and passed
    through ReLU.

    Parameters
    ----------
    conv : Type[nn.Conv2d | nn.Conv3d]
        Convolutional layer class to use.

    batchnorm : Type[nn.BatchNorm2d | nn.BatchNorm3d]
        Batch normalization layer class to use.

    num_in_channels : int
        Number of channels of the incoming tensor.

    num_filters : int
        Base filter count. It is multiplied by ``alpha`` to obtain the width
        ``W`` from which the per-stage filter counts are derived.

    alpha : float, optional
        Width multiplier. Defaults to 1.67.
    """

    def __init__(self, conv, batchnorm, num_in_channels, num_filters, alpha=1.67):
        """
        Build the shortcut, the three convolution stages and both norms.

        The stages receive ``int(W * 0.167)``, ``int(W * 0.333)`` and
        ``int(W * 0.5)`` filters respectively; the block outputs their sum.
        """
        super().__init__()
        self.alpha = alpha
        self.W = num_filters * alpha

        n_first, n_second, n_third = (int(self.W * share) for share in (0.167, 0.333, 0.5))
        n_total = n_first + n_second + n_third

        # 1x1 projection so the residual matches the concatenated channel count.
        self.shortcut = Conv_batchnorm(
            conv, batchnorm, num_in_channels, n_total, kernel_size=1, activation="None"
        )

        # Each stage consumes the output of the previous one.
        stage_layout = (
            ("conv_3x3", num_in_channels, n_first),
            ("conv_5x5", n_first, n_second),
            ("conv_7x7", n_second, n_third),
        )
        for attr_name, ch_in, ch_out in stage_layout:
            setattr(
                self,
                attr_name,
                Conv_batchnorm(conv, batchnorm, ch_in, ch_out, kernel_size=3, activation="relu"),
            )

        self.batch_norm1 = batchnorm(n_total)
        self.batch_norm2 = batchnorm(n_total)

    def forward(self, x):
        """
        Run the block on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``num_in_channels`` channels.

        Returns
        -------
        torch.Tensor
            Fused features after the residual sum, second batch normalization
            and ReLU.
        """
        residual = self.shortcut(x)

        stage_outputs = []
        current = x
        for stage in (self.conv_3x3, self.conv_5x5, self.conv_7x7):
            current = stage(current)
            stage_outputs.append(current)

        fused = self.batch_norm1(torch.cat(stage_outputs, dim=1))
        fused = self.batch_norm2(fused + residual)
        return torch.nn.functional.relu(fused)
class Respath(torch.nn.Module):
    """
    ResPath: a chain of residual convolution stages used on skip connections.

    Every stage owns a 1x1 shortcut ``Conv_batchnorm`` (no activation), a 3x3
    ``Conv_batchnorm`` (ReLU) and one extra batch normalization layer.

    Parameters
    ----------
    conv : Type[nn.Conv2d | nn.Conv3d]
        Convolutional layer class to use.

    batchnorm : Type[nn.BatchNorm2d | nn.BatchNorm3d]
        Batch normalization layer class to use.

    num_in_filters : int
        Number of channels entering the first stage.

    num_out_filters : int
        Number of channels produced by every stage.

    respath_length : int
        Number of stages to stack.
    """

    def __init__(self, conv, batchnorm, num_in_filters, num_out_filters, respath_length):
        """
        Create ``respath_length`` stages.

        Only the first stage maps ``num_in_filters`` to ``num_out_filters``;
        all later stages map ``num_out_filters`` to ``num_out_filters``.
        """
        super().__init__()
        self.respath_length = respath_length

        self.shortcuts = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        stage_in = num_in_filters
        for _ in range(self.respath_length):
            self.shortcuts.append(
                Conv_batchnorm(
                    conv, batchnorm, stage_in, num_out_filters, kernel_size=1, activation="None"
                )
            )
            self.convs.append(
                Conv_batchnorm(
                    conv, batchnorm, stage_in, num_out_filters, kernel_size=3, activation="relu"
                )
            )
            self.bns.append(batchnorm(num_out_filters))
            # After the first stage the channel count no longer changes.
            stage_in = num_out_filters

    def forward(self, x):
        """
        Pass ``x`` through every stage in order.

        Within a stage the 3x3 branch output is normalized by the stage's
        extra batch norm and rectified, the shortcut is added, and the same
        batch norm layer followed by ReLU is applied once more.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``num_in_filters`` channels.

        Returns
        -------
        torch.Tensor
            Output of the last stage.
        """
        relu = torch.nn.functional.relu
        for idx, norm in enumerate(self.bns):
            skip = self.shortcuts[idx](x)
            branch = relu(norm(self.convs[idx](x)))
            # The stage's norm layer is intentionally reused after the sum.
            x = relu(norm(branch + skip))
        return x
class MultiResUnet(torch.nn.Module):
    """
    2D/3D MultiResUNet model.

    The network has five encoder ``Multiresblock`` levels (base widths 32, 64,
    128, 256, 512), four ``Respath`` skip connections (lengths 4, 3, 2, 1) and
    four decoder levels built from transposed convolutions followed by
    ``Multiresblock`` modules. One output head is created per entry of
    ``output_channels``.

    Reference: `MultiResUNet : Rethinking the U-Net Architecture for Multimodal
    Biomedical Image Segmentation <https://arxiv.org/abs/1902.04049>`_.
    """

    def __init__(
        self,
        ndim,
        input_channels,
        alpha=1.67,
        z_down=[2, 2, 2, 2],
        output_channels=[1],
        output_channel_info=["F"],
        explicit_activations: bool = False,
        head_activations: List[str] = ["ce_sigmoid"],
        upsampling_factor=(),
        upsampling_position="pre",
        contrast: bool = False,
        contrast_proj_dim: int = 256,
        return_one_tensor: bool = False,
    ):
        """
        Create 2D/3D MultiResUNet model.

        Parameters
        ----------
        ndim : int
            Number of dimensions of the input data. ``3`` selects the 3D
            layers; any other value selects the 2D layers.

        input_channels: int
            Number of channels in image.

        alpha: float, optional
            Alpha hyperparameter (default: 1.67)

        z_down : List of ints, optional
            Downsampling used in z dimension. Set it to ``1`` if the dataset is not isotropic.
            Only read when ``ndim == 3``.

        output_channels : list of int, optional
            Output channels of the network. If one value is provided, the model will have a
            single output head. If two values are provided, the model will have two output
            heads (e.g. for multi-task learning with instance segmentation and classification).

        output_channel_info : list of str, optional
            Information about the type of output channels. Possible values are:
              - "X": where X is a letter, e.g. "F" for foreground, "D" for distance,
                "R" for rays, "C" for contours, etc.
              - "class": classification (e.g. for multi-task learning)

        explicit_activations : bool, optional
            If True, uses explicit activation functions in the last layers.

        head_activations : List[str], optional
            Activation functions to apply to each output head if `explicit_activations` is True.

        upsampling_factor : tuple of ints, optional
            Factor of upsampling for super resolution workflow for each dimension.

        upsampling_position : str, optional
            Whether the upsampling is going to be made previously (``pre`` option) to the model
            or after the model (``post`` option).

        contrast : bool, optional
            Whether to add contrastive learning head to the model. Default is ``False``.

        contrast_proj_dim : int, optional
            Dimension of the projection head for contrastive learning. Default is ``256``.

        return_one_tensor : bool, optional
            Whether to return a single tensor with all outputs concatenated (if False,
            returns a dictionary with separate entries). Default is ``False``.

        Raises
        ------
        ValueError
            If 'output_channels' is empty, if 'contrast' is True and 'output_channels' has
            more than two values, or if a "class" output is requested with explicit
            activations but no class activation could be prepared.
        """
        # NOTE: the list defaults above are only read, never mutated, in this class.
        super().__init__()
        if len(output_channels) == 0:
            raise ValueError("'output_channels' needs to has at least one value")
        if contrast and len(output_channels) > 2:
            raise ValueError("If 'contrast' is True, 'output_channels' can only have two values at max: one for the main output and one for the class.")

        print("Selected output channels:")
        for i, info in enumerate(output_channel_info):
            print(f" - {i} channel for {info} output")

        self.ndim = ndim
        self.alpha = alpha
        self.output_channels = output_channels
        self.output_channel_info = output_channel_info
        self.return_class = "class" in output_channel_info
        self.contrast = contrast
        self.explicit_activations = explicit_activations
        self.return_one_tensor = return_one_tensor
        if self.explicit_activations:
            assert len(head_activations) == sum(output_channels), (
                "If 'explicit_activations' is True, 'head_activations' needs to "
                "have the same number of values as 'output_channels'"
            )
            self.head_activations, self.class_head_activations = prepare_activation_layers(
                head_activations, output_channel_info, output_channels
            )
            if self.return_class and self.class_head_activations is None:
                raise ValueError("If 'return_class' is True, 'head_activations' must be provided.")

        if self.ndim == 3:
            conv = nn.Conv3d
            convtranspose = nn.ConvTranspose3d
            batchnorm_layer = nn.BatchNorm3d
            pooling = nn.MaxPool3d
            dropout = nn.Dropout3d
        else:
            conv = nn.Conv2d
            convtranspose = nn.ConvTranspose2d
            batchnorm_layer = nn.BatchNorm2d
            pooling = nn.MaxPool2d
            dropout = nn.Dropout2d

        def level_kernel(level):
            # Pooling / upsampling kernel of an encoder level; z uses z_down in 3D.
            return (z_down[level], 2, 2) if self.ndim == 3 else (2, 2)

        # Super-resolution
        self.pre_upsampling = None
        if len(upsampling_factor) > 0 and upsampling_position == "pre":
            self.pre_upsampling = convtranspose(
                input_channels,
                input_channels,
                kernel_size=upsampling_factor,
                stride=upsampling_factor,
            )

        # Encoder Path
        self.multiresblock1 = Multiresblock(conv, batchnorm_layer, input_channels, 32)
        self.in_filters1 = self._multires_out_channels(32, self.alpha)
        self.pool1 = pooling(level_kernel(0))
        self.respath1 = Respath(conv, batchnorm_layer, self.in_filters1, 32, respath_length=4)

        self.multiresblock2 = Multiresblock(conv, batchnorm_layer, self.in_filters1, 32 * 2)
        self.in_filters2 = self._multires_out_channels(32 * 2, self.alpha)
        self.pool2 = pooling(level_kernel(1))
        self.respath2 = Respath(conv, batchnorm_layer, self.in_filters2, 32 * 2, respath_length=3)

        self.multiresblock3 = Multiresblock(conv, batchnorm_layer, self.in_filters2, 32 * 4)
        self.in_filters3 = self._multires_out_channels(32 * 4, self.alpha)
        self.pool3 = pooling(level_kernel(2))
        self.respath3 = Respath(conv, batchnorm_layer, self.in_filters3, 32 * 4, respath_length=2)

        self.multiresblock4 = Multiresblock(conv, batchnorm_layer, self.in_filters3, 32 * 8)
        self.in_filters4 = self._multires_out_channels(32 * 8, self.alpha)
        self.pool4 = pooling(level_kernel(3))
        self.respath4 = Respath(conv, batchnorm_layer, self.in_filters4, 32 * 8, respath_length=1)

        self.multiresblock5 = Multiresblock(conv, batchnorm_layer, self.in_filters4, 32 * 16)
        self.in_filters5 = self._multires_out_channels(32 * 16, self.alpha)

        # Decoder path: each level undoes the pooling of the matching encoder level,
        # then fuses the upsampled features with the ResPath output (hence the "* 2").
        mpool = level_kernel(3)
        self.upsample6 = convtranspose(self.in_filters5, 32 * 8, kernel_size=mpool, stride=mpool)
        self.concat_filters1 = 32 * 8 * 2
        self.multiresblock6 = Multiresblock(conv, batchnorm_layer, self.concat_filters1, 32 * 8)
        self.in_filters6 = self._multires_out_channels(32 * 8, self.alpha)

        mpool = level_kernel(2)
        self.upsample7 = convtranspose(self.in_filters6, 32 * 4, kernel_size=mpool, stride=mpool)
        self.concat_filters2 = 32 * 4 * 2
        self.multiresblock7 = Multiresblock(conv, batchnorm_layer, self.concat_filters2, 32 * 4)
        self.in_filters7 = self._multires_out_channels(32 * 4, self.alpha)

        mpool = level_kernel(1)
        self.upsample8 = convtranspose(self.in_filters7, 32 * 2, kernel_size=mpool, stride=mpool)
        self.concat_filters3 = 32 * 2 * 2
        self.multiresblock8 = Multiresblock(conv, batchnorm_layer, self.concat_filters3, 32 * 2)
        self.in_filters8 = self._multires_out_channels(32 * 2, self.alpha)

        mpool = level_kernel(0)
        self.upsample9 = convtranspose(self.in_filters8, 32, kernel_size=mpool, stride=mpool)
        self.concat_filters4 = 32 * 2
        self.multiresblock9 = Multiresblock(conv, batchnorm_layer, self.concat_filters4, 32)
        self.in_filters9 = self._multires_out_channels(32, self.alpha)

        # Super-resolution
        self.post_upsampling = None
        if len(upsampling_factor) > 0 and upsampling_position == "post":
            self.post_upsampling = convtranspose(
                self.in_filters9,
                self.in_filters9,
                kernel_size=upsampling_factor,
                stride=upsampling_factor,
            )

        # Output heads. ``self.heads`` holds exactly one element per entry of
        # ``output_channels`` so that ``forward`` can iterate over heads.
        self.heads = nn.Sequential()
        if self.contrast:
            # Extra layers for the main output. They are wrapped in their own
            # Sequential so they count as ONE head, and the normalization layer is
            # instantiated (passing the bare class makes nn.Sequential raise).
            self.heads.append(
                nn.Sequential(
                    conv(self.in_filters9, self.in_filters9, kernel_size=3, stride=1, padding=1),
                    batchnorm_layer(self.in_filters9),
                    dropout(0.10),
                    conv(self.in_filters9, output_channels[0], kernel_size=1, stride=1, padding=0, bias=False),
                )
            )
            remaining_channels = output_channels[1:]
        else:
            remaining_channels = output_channels
        for out_ch in remaining_channels:
            self.heads.append(conv(self.in_filters9, out_ch, kernel_size=1, padding="same"))

        if self.contrast:
            self.proj_head = ProjectionHead(ndim=self.ndim, in_channels=self.in_filters9, proj_dim=contrast_proj_dim)

    @staticmethod
    def _multires_out_channels(num_filters, alpha):
        """Return the channel count a ``Multiresblock`` built with these arguments outputs."""
        width = num_filters * alpha
        return int(width * 0.167) + int(width * 0.333) + int(width * 0.5)

    def forward(self, x: torch.Tensor) -> Dict | torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, channels, height, width) for 2D or
            (batch_size, channels, depth, height, width) for 3D.

        Returns
        -------
        Dict or torch.Tensor
            The prediction tensor when it is the only output. Otherwise a dictionary
            with ``"pred"`` plus ``"class"`` and/or ``"embed"``; if ``return_one_tensor``
            is set, a single tensor is returned instead (the prediction, with the argmax
            of the class output appended as an extra channel when a class output exists).
        """
        # Super-resolution
        if self.pre_upsampling is not None:
            x = self.pre_upsampling(x)

        # Encoder: each level feeds the pooled features downwards and the
        # ResPath-refined features sideways to the decoder.
        x_multires1 = self.multiresblock1(x)
        x_pool1 = self.pool1(x_multires1)
        x_multires1 = self.respath1(x_multires1)

        x_multires2 = self.multiresblock2(x_pool1)
        x_pool2 = self.pool2(x_multires2)
        x_multires2 = self.respath2(x_multires2)

        x_multires3 = self.multiresblock3(x_pool2)
        x_pool3 = self.pool3(x_multires3)
        x_multires3 = self.respath3(x_multires3)

        x_multires4 = self.multiresblock4(x_pool3)
        x_pool4 = self.pool4(x_multires4)
        x_multires4 = self.respath4(x_multires4)

        x_multires5 = self.multiresblock5(x_pool4)

        # Decoder
        up6 = torch.cat([self.upsample6(x_multires5), x_multires4], dim=1)
        x_multires6 = self.multiresblock6(up6)

        up7 = torch.cat([self.upsample7(x_multires6), x_multires3], dim=1)
        x_multires7 = self.multiresblock7(up7)

        up8 = torch.cat([self.upsample8(x_multires7), x_multires2], dim=1)
        x_multires8 = self.multiresblock8(up8)

        up9 = torch.cat([self.upsample9(x_multires8), x_multires1], dim=1)
        x_multires9 = self.multiresblock9(up9)

        feats = x_multires9

        # Super-resolution
        if self.post_upsampling is not None:
            feats = self.post_upsampling(feats)

        # Pass the features through the output heads
        class_outs, outs = [], []
        for i, head in enumerate(self.heads):
            if "class" not in self.output_channel_info[i]:
                outs.append(head(feats))
            else:
                class_outs.append(head(feats))
        outs = torch.cat(outs, dim=1)

        # Apply activations to the output heads if explicit_activations is True
        if self.explicit_activations:
            # If there is only one activation, apply it to the whole tensor
            if len(self.head_activations) == 1:
                outs = self.head_activations[0](outs)
            else:
                # Build a new tensor instead of writing into slices of ``outs``:
                # in-place writes can invalidate tensors saved for backward.
                pieces = [act(outs[:, i:i + 1]) for i, act in enumerate(self.head_activations)]
                if len(pieces) < outs.shape[1]:
                    # Channels without an activation are passed through unchanged.
                    pieces.append(outs[:, len(pieces):])
                outs = torch.cat(pieces, dim=1)

            if self.return_class and self.class_head_activations is not None:
                for i, act in enumerate(self.class_head_activations):
                    class_outs[i] = act(class_outs[i])

        out_dict = {
            "pred": outs,
        }
        if self.return_class:
            out_dict["class"] = torch.cat(class_outs, dim=1)

        # Contrastive learning head
        if self.contrast:
            out_dict["embed"] = self.proj_head(feats)

        if len(out_dict) == 1:
            return out_dict["pred"]

        if self.return_one_tensor:
            if "class" in out_dict:
                return torch.cat((out_dict["pred"], torch.argmax(out_dict["class"], dim=1).unsqueeze(1)), dim=1)
            return out_dict["pred"]
        return out_dict