biapy.models.multiresunet

This module implements the MultiResUNet architecture, a U-Net variant designed for multimodal biomedical image segmentation.

MultiResUNet enhances the standard U-Net by incorporating “MultiRes Blocks” in the encoder and decoder paths, and “ResPaths” for skip connections. These components aim to improve feature representation and information flow across different scales.

Key components implemented in this file include:

Classes:

  • Conv_batchnorm: A basic convolutional block with optional batch normalization and activation.

  • Multiresblock: The MultiRes Block, a core component that extracts features at several effective receptive fields (3x3, 5x5, 7x7) and fuses them by concatenation.

  • Respath: The ResPath module, which acts as an enhanced skip connection, applying residual convolutional blocks to features before they are concatenated in the decoder.

  • MultiResUnet: The main MultiResUNet model, combining the encoder, decoder, and skip connections using the MultiRes Blocks and ResPaths.

The implementation supports both 2D and 3D inputs, various normalization types, and optional multi-head outputs, including a contrastive learning projection.

Reference: MultiResUNet: Rethinking the U-Net Architecture for Multimodal Biomedical Image Segmentation

Code Adapted From: https://github.com/nibtehaz/MultiResUNet

class biapy.models.multiresunet.Conv_batchnorm(conv, batchnorm, num_in_filters, num_out_filters, kernel_size, stride=1, activation='relu')

Bases: Module

A basic convolutional block with optional batch normalization and activation.

This module combines a convolutional layer, a batch normalization layer, and an activation function (ReLU by default), providing a standard building block for neural networks.

Parameters:
  • conv (Torch conv layer) – Convolutional layer type to use (e.g., nn.Conv2d, nn.Conv3d).

  • batchnorm (Torch batch normalization layer) – Batch normalization layer type to use (e.g., nn.BatchNorm2d, nn.BatchNorm3d).

  • num_in_filters (int) – Number of input channels for the convolutional layer.

  • num_out_filters (int) – Number of output channels for the convolutional layer.

  • kernel_size (int or tuple of ints) – Size of the convolving kernel (e.g., 3 for 3x3, (3,3,3) for 3x3x3).

  • stride (int or tuple of ints, optional) – Stride of the convolution. Defaults to 1.

  • activation (str, optional) – Activation function to apply after convolution and batch normalization. Currently supports “relu” or “None” (for no activation). Defaults to “relu”.

forward(x)

Perform the forward pass of the convolutional block.

Applies convolution, batch normalization, and then the specified activation function to the input tensor.

Parameters:

x (torch.Tensor) – The input tensor to the block.

Returns:

The output tensor after convolution, batch normalization, and activation.

Return type:

torch.Tensor
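
The convolution, normalization, and activation sequence described above can be sketched as follows. This is a minimal illustration, not BiaPy's code: the class name ConvNormAct is hypothetical, and the padding rule (kernel_size // 2, integer kernel sizes only) is an assumption made so that a stride of 1 preserves the spatial size.

```python
import torch
import torch.nn as nn


class ConvNormAct(nn.Module):
    """Hypothetical stand-in for a conv + batch norm + activation block."""

    def __init__(self, conv, batchnorm, num_in_filters, num_out_filters,
                 kernel_size, stride=1, activation="relu"):
        super().__init__()
        # Assumption of this sketch: integer, odd kernel size with
        # kernel_size // 2 padding, so stride 1 keeps the spatial size.
        self.conv = conv(num_in_filters, num_out_filters, kernel_size,
                         stride=stride, padding=kernel_size // 2)
        self.norm = batchnorm(num_out_filters)
        # Any value other than "relu" means "no activation" in this sketch.
        self.act = nn.ReLU() if activation == "relu" else nn.Identity()

    def forward(self, x):
        # Convolution, then normalization, then activation.
        return self.act(self.norm(self.conv(x)))


# 2D usage: pass the 2D layer classes.
block_2d = ConvNormAct(nn.Conv2d, nn.BatchNorm2d, 3, 16, kernel_size=3)
y_2d = block_2d(torch.randn(2, 3, 32, 32))
print(tuple(y_2d.shape))  # (2, 16, 32, 32)

# 3D usage: only the layer classes change.
block_3d = ConvNormAct(nn.Conv3d, nn.BatchNorm3d, 1, 4, kernel_size=3)
y_3d = block_3d(torch.randn(1, 1, 8, 16, 16))
print(tuple(y_3d.shape))  # (1, 4, 8, 16, 16)
```

Passing the layer classes as arguments is what lets the same block serve both 2D and 3D models.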

class biapy.models.multiresunet.Multiresblock(conv, batchnorm, num_in_channels, num_filters, alpha=1.67)

Bases: Module

MultiRes Block as described in the MultiResUNet paper.

This block enhances feature extraction by computing features at different effective receptive fields (3x3, 5x5, 7x7) and concatenating them. It also includes a shortcut connection and batch normalization.

Parameters:
  • conv (Torch conv layer) – Convolutional layer type to use (e.g., nn.Conv2d, nn.Conv3d).

  • batchnorm (Torch batch normalization layer) – Batch normalization layer type to use (e.g., nn.BatchNorm2d, nn.BatchNorm3d).

  • num_in_channels (int) – Number of input channels coming into the MultiRes Block.

  • num_filters (int) – Base number of filters for calculating the output filter counts for the internal convolutional paths.

  • alpha (float, optional) – Alpha hyperparameter (default: 1.67). Used to scale the total number of filters, influencing the capacity of the block.
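
How num_filters and alpha might translate into per-path filter counts can be sketched as below. The split fractions (0.167, 0.333, 0.5) are taken from the original MultiResUNet implementation this module is adapted from; whether this module uses exactly the same fractions is an assumption, and the function name multires_filter_split is hypothetical.

```python
def multires_filter_split(num_filters, alpha=1.67):
    """Return (f3, f5, f7, total) filter counts for one MultiRes Block.

    Assumption: the scaled width W = alpha * num_filters is divided
    between the three paths with the fractions used in the original
    MultiResUNet code (0.167, 0.333, 0.5), each truncated to an int.
    """
    w = alpha * num_filters
    f3 = int(w * 0.167)  # filters for the 3x3 path
    f5 = int(w * 0.333)  # filters for the 5x5-equivalent path
    f7 = int(w * 0.5)    # filters for the 7x7-equivalent path
    # The concatenated output (and the shortcut) has f3 + f5 + f7 channels.
    return f3, f5, f7, f3 + f5 + f7


print(multires_filter_split(32))  # (8, 17, 26, 51)
print(multires_filter_split(64))  # (17, 35, 53, 105)
```

Under this assumption, the truncation makes the block's output width slightly smaller than alpha * num_filters (51 rather than 53.44 for num_filters=32).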

forward(x)

Perform the forward pass of the MultiRes Block.

The input x is fed to a shortcut connection and, alongside it, to three sequential convolutions whose outputs correspond to effective 3x3, 5x5, and 7x7 receptive fields. These outputs are concatenated, batch normalized, and then added to the shortcut output. A final batch normalization and ReLU activation are applied.

Parameters:

x (torch.Tensor) – The input tensor to the MultiRes Block.

Returns:

The output tensor of the MultiRes Block.

Return type:

torch.Tensor
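
The data flow of this forward pass can be sketched as follows. This is a simplified 2D-only illustration under stated assumptions, not the module's code: the class name MultiResBlockSketch is hypothetical, the three paths are modeled as chained 3x3 convolutions as in the MultiResUNet paper, and the shortcut is assumed to be a 1x1 convolution followed by batch normalization.

```python
import torch
import torch.nn as nn


class MultiResBlockSketch(nn.Module):
    """Hypothetical 2D sketch of the MultiRes Block data flow."""

    def __init__(self, in_ch, f3, f5, f7):
        super().__init__()
        out_ch = f3 + f5 + f7
        # Assumed shortcut: 1x1 conv to match the concatenated width.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch))
        # Chained 3x3 convs: the 2nd and 3rd see 5x5 and 7x7 of the input.
        self.conv3 = self._conv(in_ch, f3)
        self.conv5 = self._conv(f3, f5)
        self.conv7 = self._conv(f5, f7)
        self.bn_cat = nn.BatchNorm2d(out_ch)
        self.bn_out = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    @staticmethod
    def _conv(cin, cout):
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU())

    def forward(self, x):
        s = self.shortcut(x)
        a = self.conv3(x)   # effective 3x3 receptive field
        b = self.conv5(a)   # effective 5x5 receptive field
        c = self.conv7(b)   # effective 7x7 receptive field
        # Concatenate along channels, normalize, add the shortcut,
        # then apply the final batch normalization and ReLU.
        out = self.bn_cat(torch.cat([a, b, c], dim=1))
        return self.relu(self.bn_out(out + s))


block = MultiResBlockSketch(in_ch=3, f3=8, f5=17, f7=26)
y = block(torch.randn(2, 3, 16, 16))
print(tuple(y.shape))  # (2, 51, 16, 16)
```

Chaining small convolutions instead of running separate 5x5 and 7x7 kernels gives the larger receptive fields with fewer parameters.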

class biapy.models.multiresunet.Respath(conv, batchnorm, num_in_filters, num_out_filters, respath_length)

Bases: Module

ResPath module for MultiResUNet.

ResPath acts as an enhanced skip connection by applying a series of residual convolutional blocks to the features before they are concatenated into the decoder path. This helps to reduce the semantic gap between encoder and decoder features.

Parameters:
  • conv (Torch conv layer) – Convolutional layer type to use (e.g., nn.Conv2d, nn.Conv3d).

  • batchnorm (Torch batch normalization layer) – Batch normalization layer type to use (e.g., nn.BatchNorm2d, nn.BatchNorm3d).

  • num_in_filters (int) – Number of input channels coming into the ResPath.

  • num_out_filters (int) – Number of output channels for each convolutional block within the ResPath.

  • respath_length (int) – The number of residual convolutional blocks to stack in the ResPath.

forward(x)

Perform the forward pass of the ResPath.

The input x passes through a sequence of respath_length residual convolutional blocks. Each block involves a convolutional path and a shortcut connection, followed by batch normalization and ReLU activation.

Parameters:

x (torch.Tensor) – The input tensor to the ResPath.

Returns:

The output tensor after processing through all residual blocks in the ResPath.

Return type:

torch.Tensor
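
The stacked residual blocks described above can be sketched as follows. This is a 2D-only illustration under stated assumptions: the class name ResPathSketch is hypothetical, the convolutional path is assumed to be a 3x3 convolution, the shortcut a 1x1 convolution, and the order of normalization and activation after the addition is one plausible reading of the description.

```python
import torch
import torch.nn as nn


class ResPathSketch(nn.Module):
    """Hypothetical 2D sketch of a ResPath with `length` residual blocks."""

    def __init__(self, in_ch, out_ch, length):
        super().__init__()
        self.convs = nn.ModuleList()
        self.shortcuts = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(length):
            # Only the first block sees the incoming channel count.
            cin = in_ch if i == 0 else out_ch
            self.convs.append(nn.Sequential(
                nn.Conv2d(cin, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU()))
            self.shortcuts.append(nn.Sequential(
                nn.Conv2d(cin, out_ch, 1),
                nn.BatchNorm2d(out_ch)))
            self.norms.append(nn.BatchNorm2d(out_ch))

    def forward(self, x):
        for conv, shortcut, norm in zip(self.convs, self.shortcuts, self.norms):
            # Residual sum of both paths, then batch norm and ReLU.
            x = torch.relu(norm(conv(x) + shortcut(x)))
        return x


respath = ResPathSketch(in_ch=51, out_ch=32, length=4)
skip = respath(torch.randn(2, 51, 16, 16))
print(tuple(skip.shape))  # (2, 32, 16, 16)
```

The spatial size is unchanged, so the result can be concatenated with decoder features at the same resolution.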

class biapy.models.multiresunet.MultiResUnet(ndim, input_channels, alpha=1.67, z_down=[2, 2, 2, 2], output_channels=[1], output_channel_info=['F'], explicit_activations: bool = False, head_activations: List[str] = ['ce_sigmoid'], upsampling_factor=(), upsampling_position='pre', contrast: bool = False, contrast_proj_dim: int = 256, return_one_tensor: bool = False)

Bases: Module

forward(x: Tensor) → Dict | Tensor

Forward pass of the model.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, channels, height, width) for 2D or (batch_size, channels, depth, height, width) for 3D.

Returns:

Model output. Returns a dictionary if multi-head or contrastive outputs are enabled, otherwise returns the main prediction tensor.

Return type:

Dict or torch.Tensor
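
Because forward may return either a dictionary or a tensor, calling code may want to normalize the result. The sketch below shows one way to do that, plus a check of the input rank implied by the shapes above. The helper names and the dictionary keys "pred" and "embed" are hypothetical; the actual key names are not stated here. Plain lists stand in for tensors so that the example runs without a model.

```python
def expected_input_rank(ndim):
    """Rank of the input tensor: batch + channels + ndim spatial axes."""
    return ndim + 2


def main_prediction(output, key="pred"):
    """Return the main prediction from a dict or a bare tensor.

    `key` is a hypothetical name for the main prediction entry;
    replace it with the key the model actually uses.
    """
    if isinstance(output, dict):
        return output[key]
    return output


print(expected_input_rank(2))  # 4 -> (batch_size, channels, height, width)
print(expected_input_rank(3))  # 5 -> (batch_size, channels, depth, height, width)

# Stand-in outputs: a multi-head style dict and a single-tensor result.
print(main_prediction({"pred": [0.1, 0.9], "embed": [1.0]}))  # [0.1, 0.9]
print(main_prediction([0.1, 0.9]))  # [0.1, 0.9]
```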