# Source code for biapy.models.blocks

"""
This module contains a collection of fundamental building blocks for convolutional neural networks, primarily designed for biomedical image segmentation architectures like U-Nets and their variants.

It provides modular components for various operations, including:

- **Convolutional Layers:** Basic `ConvBlock` and `DoubleConvBlock` for standard feature extraction.
- **Attention Mechanisms:** `AttentionBlock` to integrate attention gating into skip connections.
- **Squeeze-and-Excitation Networks:** `SqExBlock` for channel-wise feature recalibration.
- **ConvNeXt Blocks:** Both `ConvNeXtBlock_V1` and `ConvNeXtBlock_V2` for modern, efficient
  feature processing with depthwise convolutions, layer normalization, and residual connections.
- **Global Response Normalization:** `GRN` layer used within ConvNeXt V2 for enhanced feature discrimination.
- **Upsampling Blocks:** `UpBlock`, `UpConvNeXtBlock_V1`, and `UpConvNeXtBlock_V2` for
  decoder paths, handling upsampling, skip connection concatenation, and feature refinement.

The blocks are designed to be flexible, supporting both 2D and 3D operations,
various normalization types, activations, and configurable parameters like kernel sizes and dropout.
"""
import torch
import torch.nn as nn
from torchvision.ops.stochastic_depth import StochasticDepth
from torchvision.ops.misc import Permute
from typing import Optional, Type, List, Tuple


[docs] class ConvBlock(nn.Module): """ Implements a standard Convolutional Block. This block consists of a convolutional layer followed by optional normalization, activation, dropout, and a Squeeze-and-Excitation (SE) block. It serves as a versatile building component in various convolutional neural network architectures. """ def __init__( self, conv, in_size, out_size, k_size, padding: int | str = "same", stride=1, bias=True, act=None, norm="none", dropout=0, se_block=False, ): """ Initialize the Convolutional Block. Sets up the core convolutional layer along with configurable normalization, activation, dropout, and an optional Squeeze-and-Excitation block. Parameters ---------- conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use (e.g., `nn.Conv2d` for 2D, `nn.Conv3d` for 3D). in_size : int Number of input feature channels. out_size : int Number of output feature channels. k_size : int or tuple Kernel size for the convolutional layer. padding : int or str, optional Padding type for the convolutional layer. Can be an integer or "same". If "same", padding is calculated to maintain output spatial dimensions. Defaults to "same". stride : int or tuple, optional Stride for the convolutional layer. Defaults to 1. bias : bool, optional Whether to include a bias term in the convolutional layer. Defaults to `True`. act : Optional[str], optional Activation layer to use after normalization. E.g., "relu", "gelu". If `None`, no activation is applied. Defaults to `None`. norm : str, optional Normalization layer type to use after convolution. Options include `'bn'` (BatchNorm), `'sync_bn'` (SyncBatchNorm), `'in'` (InstanceNorm), `'gn'` (GroupNorm), or `'none'` (no normalization). Defaults to "none". dropout : float, optional Dropout probability to apply after activation (if any). If 0, no dropout. Defaults to 0. se_block : bool, optional Whether to add a Squeeze-and-Excitation (`SqExBlock`) after all other operations in the block. Defaults to `False`. 
""" super(ConvBlock, self).__init__() block = [] block.append(conv(in_size, out_size, kernel_size=k_size, padding=padding, stride=stride, bias=bias)) if norm != "none": if conv == nn.Conv2d: block.append(get_norm_2d(norm, out_size)) else: block.append(get_norm_3d(norm, out_size)) if act: block.append(get_activation(act)) if dropout > 0: block.append(nn.Dropout(dropout)) if se_block: block.append(SqExBlock(out_size, ndim=2 if conv == nn.Conv2d else 3)) self.block = nn.Sequential(*block)
[docs] def forward(self, x): """ Perform the forward pass of the Convolutional Block. Processes the input tensor sequentially through the defined layers: convolution, optional normalization, optional activation, optional dropout, and optional Squeeze-and-Excitation. Parameters ---------- x : torch.Tensor The input feature tensor. Expected shape for 2D: (batch_size, in_size, height, width). Expected shape for 3D: (batch_size, in_size, depth, height, width). Returns ------- torch.Tensor The output tensor after passing through the block. Its shape will be (batch_size, out_size, H', W') or (batch_size, out_size, D', H', W'), where H', W' (and D') depend on `padding` and `stride`. """ out = self.block(x) return out
class DoubleConvBlock(nn.Module):
    """
    Two ``ConvBlock`` layers applied back to back.

    The first block maps ``in_size`` channels to ``out_size``; the second keeps
    ``out_size`` channels. Both blocks share every other setting.
    """

    def __init__(
        self,
        conv,
        in_size,
        out_size,
        k_size,
        act=None,
        norm="none",
        dropout=0,
        se_block=False,
    ):
        """
        Create the two chained ``ConvBlock`` layers.

        Parameters
        ----------
        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used by both inner blocks.

        in_size : int
            Number of channels entering the first block.

        out_size : int
            Number of channels produced by each of the two blocks.

        k_size : int or tuple
            Kernel size of both convolutions.

        act : str, optional
            Activation name forwarded to each ``ConvBlock``. Defaults to ``None``.

        norm : str, optional
            Normalization name forwarded to each ``ConvBlock``
            (``'bn'``, ``'sync_bn'``, ``'in'``, ``'gn'`` or ``'none'``). Defaults to ``"none"``.

        dropout : float, optional
            Dropout probability forwarded to each ``ConvBlock``. Defaults to 0.

        se_block : bool, optional
            Whether each ``ConvBlock`` ends with a Squeeze-and-Excitation block.
            Defaults to ``False``.
        """
        super().__init__()
        # Everything except the input width is identical for the two blocks.
        shared = dict(
            conv=conv,
            out_size=out_size,
            k_size=k_size,
            act=act,
            norm=norm,
            dropout=dropout,
            se_block=se_block,
        )
        self.block = nn.Sequential(
            ConvBlock(in_size=in_size, **shared),
            ConvBlock(in_size=out_size, **shared),
        )

    def forward(self, x):
        """
        Apply both convolutional blocks in sequence.

        Parameters
        ----------
        x : torch.Tensor
            Input features, ``(batch, in_size, H, W)`` for 2D or
            ``(batch, in_size, D, H, W)`` for 3D.

        Returns
        -------
        torch.Tensor
            Output features with ``out_size`` channels.
        """
        return self.block(x)
class ConvNeXtBlock_V1(nn.Module):
    """
    Single ConvNeXt V1 block.

    The residual branch is: depthwise convolution -> permute to channels-last ->
    layer normalization -> Linear (``dim`` -> ``4 * dim``) -> GELU ->
    Linear (``4 * dim`` -> ``dim``) -> permute back to channels-first. The branch
    output is optionally multiplied by a learnable layer-scale parameter, passed
    through stochastic depth and added to the block input.
    """

    def __init__(self, ndim, conv, dim, layer_scale=1e-6, stochastic_depth_prob=0.0, layer_norm=None, k_size=7):
        """
        Build the ConvNeXt V1 block.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions of the input (2 or 3).

        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used for the depthwise convolution.

        dim : int
            Number of input (and output) channels.

        layer_scale : float, optional
            Initial value of the learnable per-channel scale applied to the
            residual branch. No scale parameter is created when it is not
            greater than 0. Defaults to 1e-6.

        stochastic_depth_prob : float, optional
            Probability handed to ``StochasticDepth`` (``"row"`` mode) for the
            residual branch. Defaults to 0.0.

        layer_norm : Optional[Type[nn.LayerNorm]], optional
            Layer normalization class; ``nn.LayerNorm`` when ``None``.
            Defaults to ``None``.

        k_size : int or sequence of int, optional
            Size of the depthwise convolution window. An int is used for every
            spatial axis; a sequence must have ``ndim`` entries. Defaults to 7.

        Raises
        ------
        ValueError
            If ``ndim`` is not 2 or 3, or if a sequence ``k_size`` does not have
            ``ndim`` entries.
        """
        super().__init__()
        if ndim not in (2, 3):
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")
        if layer_norm is None:
            layer_norm = nn.LayerNorm

        # Accept both a scalar and a per-axis kernel size.
        if isinstance(k_size, int):
            k_size = (k_size,) * ndim
        else:
            k_size = tuple(k_size)
            if len(k_size) != ndim:
                raise ValueError(f"k_size must have {ndim} entries, got {len(k_size)}")
        # k // 2 keeps the spatial size unchanged for odd kernels, which the
        # residual addition in forward() requires (7 -> 3, 1 -> 0).
        pad = tuple(k // 2 for k in k_size)

        # (B, C, *spatial) -> (B, *spatial, C) and back, so LayerNorm / Linear
        # act on the channel axis.
        pre_ln_permutation = Permute([0, *range(2, ndim + 2), 1])
        post_ln_permutation = Permute([0, ndim + 1, *range(1, ndim + 1)])
        layer_scale_dim = (dim,) + (1,) * ndim

        self.block = nn.Sequential(
            conv(dim, dim, kernel_size=k_size, padding=pad, groups=dim, bias=True),  # depthwise conv
            pre_ln_permutation,
            layer_norm(dim, eps=1e-6),
            # pointwise/1x1 convs, implemented with linear layers
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            post_ln_permutation,
        )
        self.layer_scale = (
            nn.Parameter(torch.ones(layer_scale_dim) * layer_scale, requires_grad=True) if layer_scale > 0 else None
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, x):
        """
        Apply the block: ``x + stochastic_depth(layer_scale * branch(x))``.

        Parameters
        ----------
        x : torch.Tensor
            Input features, ``(batch, dim, H, W)`` for 2D or
            ``(batch, dim, D, H, W)`` for 3D.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.
        """
        result = self.block(x)
        if self.layer_scale is not None:
            result = self.layer_scale * result
        return x + self.stochastic_depth(result)
class GRN(nn.Module):
    """
    Global Response Normalization (GRN) layer, as used in ConvNeXt V2.

    Operates on channels-last tensors. For every channel the L2 norm over all
    spatial positions is computed, divided by the mean of those norms across
    channels, and used to rescale the input; learnable ``gamma`` / ``beta`` are
    applied and the input is added back as a residual.
    """

    def __init__(self, dim):
        """
        Create the learnable scale and shift, both initialized to zero.

        Parameters
        ----------
        dim : int
            Number of channels (size of the last axis of the input).
        """
        super().__init__()
        # Shape (1, 1, 1, dim) broadcasts against both (B, H, W, C) and
        # (B, D, H, W, C) inputs because the channel axis is last.
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        """
        Apply global response normalization.

        Parameters
        ----------
        x : torch.Tensor
            Channels-last features: ``(B, H, W, C)`` for 2D or
            ``(B, D, H, W, C)`` for 3D.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.

        Raises
        ------
        ValueError
            If ``x`` has fewer than 3 dimensions (no spatial axis to reduce).
        """
        if x.dim() < 3:
            raise ValueError(f"GRN expects a (B, *spatial, C) tensor, got {x.dim()} dimensions")
        # Reduce over every axis between batch and channel, so 3D inputs
        # include the W axis as well (not only the first two spatial axes).
        spatial_dims = tuple(range(1, x.dim() - 1))
        Gx = torch.linalg.vector_norm(x, ord=2, dim=spatial_dims, keepdim=True)
        # Relative importance of each channel w.r.t. the mean channel response.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
class ConvNeXtBlock_V2(nn.Module):
    """
    Single ConvNeXt V2 block.

    The residual branch is: depthwise convolution -> permute to channels-last ->
    layer normalization -> Linear (``dim`` -> ``4 * dim``) -> GELU -> ``GRN`` ->
    Linear (``4 * dim`` -> ``dim``) -> permute back to channels-first. The branch
    output is passed through stochastic depth and added to the block input.
    """

    def __init__(self, ndim, conv, dim, stochastic_depth_prob=0.0, layer_norm=None, k_size=7):
        """
        Build the ConvNeXt V2 block.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions of the input (2 or 3).

        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used for the depthwise convolution.

        dim : int
            Number of input (and output) channels.

        stochastic_depth_prob : float, optional
            Probability handed to ``StochasticDepth`` (``"row"`` mode) for the
            residual branch. Defaults to 0.0.

        layer_norm : Optional[Type[nn.LayerNorm]], optional
            Layer normalization class; ``nn.LayerNorm`` when ``None``.
            Defaults to ``None``.

        k_size : int or sequence of int, optional
            Size of the depthwise convolution window. An int is used for every
            spatial axis; a sequence must have ``ndim`` entries. Defaults to 7.

        Raises
        ------
        ValueError
            If ``ndim`` is not 2 or 3, or if a sequence ``k_size`` does not have
            ``ndim`` entries.
        """
        super().__init__()
        if ndim not in (2, 3):
            raise ValueError(f"ndim must be 2 or 3, got {ndim!r}")
        if layer_norm is None:
            layer_norm = nn.LayerNorm

        # Accept both a scalar and a per-axis kernel size.
        if isinstance(k_size, int):
            k_size = (k_size,) * ndim
        else:
            k_size = tuple(k_size)
            if len(k_size) != ndim:
                raise ValueError(f"k_size must have {ndim} entries, got {len(k_size)}")
        # k // 2 keeps the spatial size unchanged for odd kernels, which the
        # residual addition in forward() requires (7 -> 3, 1 -> 0).
        pad = tuple(k // 2 for k in k_size)

        # (B, C, *spatial) -> (B, *spatial, C) and back, so LayerNorm / Linear /
        # GRN act on the channel axis.
        pre_ln_permutation = Permute([0, *range(2, ndim + 2), 1])
        post_ln_permutation = Permute([0, ndim + 1, *range(1, ndim + 1)])

        self.block = nn.Sequential(
            conv(dim, dim, kernel_size=k_size, padding=pad, groups=dim, bias=True),  # depthwise conv
            pre_ln_permutation,
            layer_norm(dim, eps=1e-6),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            GRN(4 * dim),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            post_ln_permutation,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, x):
        """
        Apply the block: ``x + stochastic_depth(branch(x))``.

        Parameters
        ----------
        x : torch.Tensor
            Input features, ``(batch, dim, H, W)`` for 2D or
            ``(batch, dim, D, H, W)`` for 3D.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.
        """
        result = self.block(x)
        return x + self.stochastic_depth(result)
class UpBlock(nn.Module):
    """
    Decoder-side upsampling block for U-Net-like networks.

    The input is upsampled, concatenated along the channel axis with the encoder
    skip connection (optionally filtered by an attention gate) and refined by a
    ``DoubleConvBlock``.
    """

    def __init__(
        self,
        ndim,
        convtranspose,
        in_size,
        out_size,
        z_down,
        yx_down,
        up_mode,
        conv,
        k_size,
        act=None,
        norm="none",
        dropout=0,
        attention_gate=False,
        se_block=False,
    ):
        """
        Assemble the upsampling path, the optional attention gate and the refinement block.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions of the data (2 or 3).

        convtranspose : Type[nn.ConvTranspose2d | nn.ConvTranspose3d]
            Transposed convolution class, used when ``up_mode`` is ``'convtranspose'``.

        in_size : int
            Number of channels of the tensor being upsampled.

        out_size : int
            Number of channels produced by the block.

        z_down : int
            Scale factor along z; only used when ``ndim`` is 3.

        yx_down : int
            Scale factor along y and x.

        up_mode : str
            ``'convtranspose'`` upsamples with a transposed convolution whose kernel
            and stride equal the scale factors. ``'upsampling'`` uses ``nn.Upsample``
            (bilinear when ``ndim`` is 2, trilinear otherwise) followed by a 1x1
            convolution. Any other value adds no upsampling layer.

        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used by the inner blocks.

        k_size : int or tuple
            Kernel size of the ``DoubleConvBlock`` convolutions.

        act : Optional[str], optional
            Activation name, applied after upsampling and inside the
            ``DoubleConvBlock``. Defaults to ``None``.

        norm : str, optional
            Normalization name (``'bn'``, ``'sync_bn'``, ``'in'``, ``'gn'`` or ``'none'``),
            used after upsampling, in the attention gate and in the ``DoubleConvBlock``.
            Defaults to ``"none"``.

        dropout : float, optional
            Dropout probability of the ``DoubleConvBlock``. Defaults to 0.

        attention_gate : bool, optional
            Gate the skip connection with an ``AttentionBlock`` before concatenation.
            Defaults to ``False``.

        se_block : bool, optional
            Enable Squeeze-and-Excitation inside the ``DoubleConvBlock``.
            Defaults to ``False``.
        """
        super().__init__()
        self.ndim = ndim
        scale = (z_down, yx_down, yx_down) if ndim == 3 else (yx_down, yx_down)

        up_layers = []
        if up_mode == "convtranspose":
            up_layers.append(convtranspose(in_size, out_size, kernel_size=scale, stride=scale))
        elif up_mode == "upsampling":
            up_layers.append(nn.Upsample(mode="bilinear" if ndim == 2 else "trilinear", scale_factor=scale))
            up_layers.append(conv(in_size, out_size, kernel_size=1))
        if norm != "none":
            make_norm = get_norm_2d if conv == nn.Conv2d else get_norm_3d
            up_layers.append(make_norm(norm, out_size))
        if act is not None:
            up_layers.append(get_activation(act))
        self.up = nn.Sequential(*up_layers)

        self.attention_gate = (
            AttentionBlock(conv=conv, in_size=out_size, out_size=out_size // 2, norm=norm) if attention_gate else None
        )

        # The concatenation of upsampled features and skip connection doubles the channels.
        self.conv_block = DoubleConvBlock(
            conv=conv,
            in_size=out_size * 2,
            out_size=out_size,
            k_size=k_size,
            act=act,
            norm=norm,
            dropout=dropout,
            se_block=se_block,
        )

    def forward(self, x, bridge):
        """
        Upsample ``x``, merge it with ``bridge`` and refine the result.

        Parameters
        ----------
        x : torch.Tensor
            Lower-resolution decoder features with ``in_size`` channels.

        bridge : torch.Tensor
            Encoder skip connection with ``out_size`` channels whose spatial size
            matches the upsampled ``x``.

        Returns
        -------
        torch.Tensor
            Features with ``out_size`` channels at the upsampled resolution.
        """
        up = self.up(x)
        skip = bridge if self.attention_gate is None else self.attention_gate(up, bridge)
        return self.conv_block(torch.cat([up, skip], 1))
class UpConvNeXtBlock_V1(nn.Module):
    """
    Decoder-side upsampling block built from ConvNeXt V1 components.

    The input is layer-normalized and upsampled, concatenated with the encoder
    skip connection (optionally filtered by an attention gate), reduced back to
    ``out_size`` channels by a 1x1 ``ConvBlock`` and refined by a stack of
    ``ConvNeXtBlock_V1`` layers.
    """

    def __init__(
        self,
        ndim,
        convtranspose,
        in_size,
        out_size,
        z_down,
        yx_down,
        up_mode,
        conv,
        attention_gate=False,
        se_block=False,
        cn_layers=1,
        sd_probs=[0.0],
        layer_scale=1e-6,
        layer_norm=None,
        k_size=7,
    ):
        """
        Assemble the upsampling path, attention gate, channel-reduction block and ConvNeXt stack.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions of the data (2 or 3).

        convtranspose : Type[nn.ConvTranspose2d | nn.ConvTranspose3d]
            Transposed convolution class, used when ``up_mode`` is ``'convtranspose'``.

        in_size : int
            Number of channels of the tensor being upsampled.

        out_size : int
            Number of channels produced by the block.

        z_down : int
            Scale factor along z; only used when ``ndim`` is 3.

        yx_down : int
            Scale factor along y and x.

        up_mode : str
            ``'convtranspose'`` upsamples with a transposed convolution whose kernel
            and stride equal the scale factors. ``'upsampling'`` uses ``nn.Upsample``
            (bilinear when ``ndim`` is 2, trilinear otherwise) followed by a 1x1
            convolution. Any other value adds no upsampling layer.

        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used by the inner blocks.

        attention_gate : bool, optional
            Gate the skip connection with an ``AttentionBlock`` before concatenation.
            Defaults to ``False``.

        se_block : bool, optional
            Enable Squeeze-and-Excitation in the 1x1 ``ConvBlock`` applied after
            concatenation. Defaults to ``False``.

        cn_layers : int, optional
            Number of ``ConvNeXtBlock_V1`` layers in the stack. Defaults to 1.

        sd_probs : list of float, optional
            Stochastic depth probabilities, indexed from ``cn_layers - 1`` down to 0
            while the stack is built, so it needs at least ``cn_layers`` entries.
            The list is only read, never modified. Defaults to ``[0.0]``.

        layer_scale : float, optional
            Layer scale value forwarded to each ``ConvNeXtBlock_V1``. Defaults to 1e-6.

        layer_norm : Optional[Type[nn.LayerNorm]], optional
            Layer normalization class, applied before upsampling and forwarded to
            the ConvNeXt blocks; ``nn.LayerNorm`` when ``None``. Defaults to ``None``.

        k_size : int or tuple, optional
            Depthwise kernel size forwarded to each ``ConvNeXtBlock_V1``. Defaults to 7.
        """
        super().__init__()
        self.ndim = ndim
        is_3d = ndim == 3
        scale = (z_down, yx_down, yx_down) if is_3d else (yx_down, yx_down)

        # Channels-first <-> channels-last shuffles wrapped around the layer norm.
        to_channels_last = Permute([0, 2, 3, 4, 1] if is_3d else [0, 2, 3, 1])
        to_channels_first = Permute([0, 4, 1, 2, 3] if is_3d else [0, 3, 1, 2])
        if layer_norm is None:
            layer_norm = nn.LayerNorm

        up_layers = [nn.Sequential(to_channels_last, layer_norm(in_size), to_channels_first)]
        if up_mode == "convtranspose":
            up_layers.append(convtranspose(in_size, out_size, kernel_size=scale, stride=scale))
        elif up_mode == "upsampling":
            up_layers.append(nn.Upsample(mode="bilinear" if ndim == 2 else "trilinear", scale_factor=scale))
            up_layers.append(conv(in_size, out_size, kernel_size=1))
        self.up = nn.Sequential(*up_layers)

        self.attention_gate = (
            AttentionBlock(conv=conv, in_size=out_size, out_size=out_size // 2) if attention_gate else None
        )

        # 1x1 convolution that halves the channels of the concatenated tensor.
        self.conv_block = ConvBlock(conv, in_size=out_size * 2, out_size=out_size, k_size=1, se_block=se_block)

        # The probabilities are consumed from the last index to the first.
        self.cn_block = nn.Sequential(
            *[
                ConvNeXtBlock_V1(ndim, conv, out_size, layer_scale, sd_probs[idx], layer_norm=layer_norm, k_size=k_size)
                for idx in range(cn_layers - 1, -1, -1)
            ]
        )

    def forward(self, x, bridge):
        """
        Upsample ``x``, merge it with ``bridge`` and refine with the ConvNeXt V1 stack.

        Parameters
        ----------
        x : torch.Tensor
            Lower-resolution decoder features with ``in_size`` channels.

        bridge : torch.Tensor
            Encoder skip connection with ``out_size`` channels whose spatial size
            matches the upsampled ``x``.

        Returns
        -------
        torch.Tensor
            Features with ``out_size`` channels at the upsampled resolution.
        """
        up = self.up(x)
        skip = bridge if self.attention_gate is None else self.attention_gate(up, bridge)
        merged = self.conv_block(torch.cat([up, skip], 1))
        return self.cn_block(merged)
class UpConvNeXtBlock_V2(nn.Module):
    """
    Decoder-side upsampling block built from ConvNeXt V2 components.

    The input is layer-normalized and upsampled, concatenated with the encoder
    skip connection (optionally filtered by an attention gate), reduced back to
    ``out_size`` channels by a 1x1 ``ConvBlock`` and refined by a stack of
    ``ConvNeXtBlock_V2`` layers.
    """

    def __init__(
        self,
        ndim,
        convtranspose,
        in_size,
        out_size,
        z_down,
        yx_down,
        up_mode,
        conv,
        attention_gate=False,
        se_block=False,
        cn_layers=1,
        sd_probs=[0.0],
        layer_norm=None,
        k_size=7,
    ):
        """
        Assemble the upsampling path, attention gate, channel-reduction block and ConvNeXt stack.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions of the data (2 or 3).

        convtranspose : Type[nn.ConvTranspose2d | nn.ConvTranspose3d]
            Transposed convolution class, used when ``up_mode`` is ``'convtranspose'``.

        in_size : int
            Number of channels of the tensor being upsampled.

        out_size : int
            Number of channels produced by the block.

        z_down : int
            Scale factor along z; only used when ``ndim`` is 3.

        yx_down : int
            Scale factor along y and x.

        up_mode : str
            ``'convtranspose'`` upsamples with a transposed convolution whose kernel
            and stride equal the scale factors. ``'upsampling'`` uses ``nn.Upsample``
            (bilinear when ``ndim`` is 2, trilinear otherwise) followed by a 1x1
            convolution. Any other value adds no upsampling layer.

        conv : Type[nn.Conv2d | nn.Conv3d]
            Convolution class used by the inner blocks.

        attention_gate : bool, optional
            Gate the skip connection with an ``AttentionBlock`` before concatenation.
            Defaults to ``False``.

        se_block : bool, optional
            Enable Squeeze-and-Excitation in the 1x1 ``ConvBlock`` applied after
            concatenation. Defaults to ``False``.

        cn_layers : int, optional
            Number of ``ConvNeXtBlock_V2`` layers in the stack. Defaults to 1.

        sd_probs : list of float, optional
            Stochastic depth probabilities, indexed from ``cn_layers - 1`` down to 0
            while the stack is built, so it needs at least ``cn_layers`` entries.
            The list is only read, never modified. Defaults to ``[0.0]``.

        layer_norm : Optional[Type[nn.LayerNorm]], optional
            Layer normalization class, applied before upsampling and forwarded to
            the ConvNeXt blocks; ``nn.LayerNorm`` when ``None``. Defaults to ``None``.

        k_size : int or tuple, optional
            Depthwise kernel size forwarded to each ``ConvNeXtBlock_V2``. Defaults to 7.
        """
        super().__init__()
        self.ndim = ndim
        is_3d = ndim == 3
        scale = (z_down, yx_down, yx_down) if is_3d else (yx_down, yx_down)

        # Channels-first <-> channels-last shuffles wrapped around the layer norm.
        to_channels_last = Permute([0, 2, 3, 4, 1] if is_3d else [0, 2, 3, 1])
        to_channels_first = Permute([0, 4, 1, 2, 3] if is_3d else [0, 3, 1, 2])
        if layer_norm is None:
            layer_norm = nn.LayerNorm

        up_layers = [nn.Sequential(to_channels_last, layer_norm(in_size), to_channels_first)]
        if up_mode == "convtranspose":
            up_layers.append(convtranspose(in_size, out_size, kernel_size=scale, stride=scale))
        elif up_mode == "upsampling":
            up_layers.append(nn.Upsample(mode="bilinear" if ndim == 2 else "trilinear", scale_factor=scale))
            up_layers.append(conv(in_size, out_size, kernel_size=1))
        self.up = nn.Sequential(*up_layers)

        self.attention_gate = (
            AttentionBlock(conv=conv, in_size=out_size, out_size=out_size // 2) if attention_gate else None
        )

        # 1x1 convolution that halves the channels of the concatenated tensor.
        self.conv_block = ConvBlock(conv, in_size=out_size * 2, out_size=out_size, k_size=1, se_block=se_block)

        # The probabilities are consumed from the last index to the first.
        self.cn_block = nn.Sequential(
            *[
                ConvNeXtBlock_V2(ndim, conv, out_size, sd_probs[idx], layer_norm=layer_norm, k_size=k_size)
                for idx in range(cn_layers - 1, -1, -1)
            ]
        )

    def forward(self, x, bridge):
        """
        Upsample ``x``, merge it with ``bridge`` and refine with the ConvNeXt V2 stack.

        Parameters
        ----------
        x : torch.Tensor
            Lower-resolution decoder features with ``in_size`` channels.

        bridge : torch.Tensor
            Encoder skip connection with ``out_size`` channels whose spatial size
            matches the upsampled ``x``.

        Returns
        -------
        torch.Tensor
            Features with ``out_size`` channels at the upsampled resolution.
        """
        up = self.up(x)
        skip = bridge if self.attention_gate is None else self.attention_gate(up, bridge)
        merged = self.conv_block(torch.cat([up, skip], 1))
        return self.cn_block(merged)
class AttentionBlock(nn.Module):
    """
    Implements an Attention Block, as proposed in Attention U-Net.

    The block computes a single-channel attention map from a gating signal ``g`` and a skip
    connection ``x`` and uses it to rescale ``x``. Both inputs are projected with independent
    1x1 convolutions (``w_g`` and ``w_x``), summed, passed through a ReLU and finally reduced
    to one channel followed by a Sigmoid (``psi``).

    Reference: `Attention U-Net: Learning Where to Look for the Pancreas <https://arxiv.org/abs/1804.03999>`_.
    """

    def __init__(self, conv, in_size, out_size, norm="none"):
        """
        Initialize the Attention Block.

        Parameters
        ----------
        conv : Type[nn.Conv2d | nn.Conv3d]
            The convolutional layer type to use (e.g., `nn.Conv2d` for 2D, `nn.Conv3d` for 3D).

        in_size : int
            Number of input feature channels of both the gating signal (`g`) and the skip
            connection input (`x`).

        out_size : int
            Number of output channels of the `w_g` and `w_x` projections. The `psi` layer
            reduces this to 1 channel.

        norm : str, optional
            Normalization layer type appended after each of the three 1x1 convolutions.
            Options include `'bn'`, `'sync_bn'`, `'in'`, `'gn'`, or `'none'` (no normalization).
            Defaults to "none".
        """
        super(AttentionBlock, self).__init__()

        # NOTE: earlier revisions appended the `w_x` normalization to the (already consumed)
        # `w_g` list, so `w_x` silently ended up without normalization. With `norm != "none"`
        # `w_x` now owns a normalization layer, which adds `w_x.1.*` entries to the state dict.
        self.w_g = nn.Sequential(*self._conv_norm(conv, in_size, out_size, norm))
        self.w_x = nn.Sequential(*self._conv_norm(conv, in_size, out_size, norm))
        self.psi = nn.Sequential(*self._conv_norm(conv, out_size, 1, norm), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _conv_norm(conv, in_channels, out_channels, norm):
        """Return a 1x1 convolution followed, when requested, by its normalization layer."""
        layers = [conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)]
        if norm != "none":
            if conv == nn.Conv2d:
                layers.append(get_norm_2d(norm, out_channels))
            else:
                layers.append(get_norm_3d(norm, out_channels))
        return layers

    def forward(self, g, x):
        """
        Gate the skip connection ``x`` with attention coefficients derived from ``g`` and ``x``.

        Parameters
        ----------
        g : torch.Tensor
            The gating signal tensor from the coarser (upsampled) pathway.
            Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).

        x : torch.Tensor
            The skip connection tensor from the corresponding encoder pathway. Its spatial
            dimensions must match those of `g`, because both projections are summed.

        Returns
        -------
        torch.Tensor
            ``x`` multiplied by the single-channel attention map; same shape as ``x``.
        """
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        # psi has one channel and is broadcast over the channels of x.
        return psi * x
class SqExBlock(nn.Module):
    """
    Squeeze-and-Excitation (SE) block that rescales each channel of its input.

    A global average pooling collapses the spatial dimensions into one descriptor per
    channel. Two bias-free linear layers (with a ReLU in between and a Sigmoid at the end)
    turn that descriptor into one weight per channel, which multiplies the input.

    Reference: `Squeeze and Excitation Networks <https://arxiv.org/abs/1709.01507>`_.
    Credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4
    """

    def __init__(self, c, r=16, ndim=2):
        """
        Initialize the Squeeze-and-Excitation block.

        Parameters
        ----------
        c : int
            Number of input channels to the block.

        r : int, optional
            Reduction ratio; the hidden layer has `c // r` units. Defaults to 16.

        ndim : int, optional
            Dimensionality of the data. 2 selects `nn.AdaptiveAvgPool2d`; any other value
            selects `nn.AdaptiveAvgPool3d`. Defaults to 2.
        """
        super().__init__()
        self.ndim = ndim
        pool_cls = nn.AdaptiveAvgPool2d if ndim == 2 else nn.AdaptiveAvgPool3d
        self.squeeze = pool_cls(1)
        hidden = c // r
        self.excitation = nn.Sequential(
            nn.Linear(c, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Recalibrate the channels of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input feature tensor. Expected shape for 2D: (batch_size, channels, height, width).
            Expected shape for 3D: (batch_size, channels, depth, height, width).

        Returns
        -------
        torch.Tensor
            ``x`` multiplied channel-wise by the predicted weights; same shape as ``x``.
        """
        batch, channels = x.shape[0], x.shape[1]
        # Squeeze: one scalar per (sample, channel); excitation: one weight per channel.
        weights = self.excitation(self.squeeze(x).view(batch, channels))
        trailing = (1, 1) if self.ndim == 2 else (1, 1, 1)
        weights = weights.view(batch, channels, *trailing)
        return x * weights.expand_as(x)
class ResConvBlock(nn.Module):
    """
    Implements a Residual Convolutional Block.

    The block computes ``block(x) + shortcut(x)`` and optionally feeds the sum through a
    Squeeze-and-Excitation block. ``block`` is made of optional pre-activation layers
    (normalization and activation) followed by two `ConvBlock` instances; only the first
    one receives `act`, `norm` and `dropout`. ``shortcut`` is a convolution (plus optional
    normalization) unless `extra_conv` is enabled, in which case an additional `ConvBlock`
    is applied to the input beforehand and the shortcut becomes the identity.
    """

    def __init__(
        self,
        conv,
        in_size,
        out_size,
        k_size,
        act=None,
        norm="none",
        dropout=0,
        skip_k_size: int | Tuple[int, ...] = 1,
        skip_norm="none",
        first_block=False,
        se_block=False,
        extra_conv=False,
    ):
        """
        Initialize a Residual Convolutional Block.

        Parameters
        ----------
        conv : Type[nn.Conv2d | nn.Conv3d]
            The convolutional layer type to use (e.g., `nn.Conv2d` for 2D, `nn.Conv3d` for 3D).

        in_size : int
            Number of input feature channels to the block.

        out_size : int
            Number of output feature channels for the convolutional layers within the block.

        k_size : int or tuple
            Kernel size for the main convolutional layers within the block.

        act : Optional[str], optional
            Activation name. It is passed to the first `ConvBlock` (and to the extra one when
            `extra_conv` is set) and, unless `first_block` is `True`, it is also used for the
            pre-activation layer. If `None`, no activation is applied. Defaults to `None`.

        norm : str, optional
            Normalization layer type used for the pre-activation and within the `ConvBlock`
            components that receive it. Options include `'bn'` (BatchNorm), `'sync_bn'`
            (SyncBatchNorm), `'in'` (InstanceNorm), `'gn'` (GroupNorm), or `'none'`
            (no normalization). Defaults to "none".

        dropout : float, optional
            Dropout value forwarded to the `ConvBlock` components that receive it.
            If 0, no dropout is applied. Defaults to 0.

        skip_k_size : int or tuple, optional
            Kernel size for the convolution in the skip connection path (only built when
            `extra_conv` is `False`). Defaults to 1.

        skip_norm : str, optional
            Normalization layer type to use in the skip connection path.
            Options are `'bn'`, `'sync_bn'`, `'in'`, `'gn'`, or `'none'`. Defaults to "none".

        first_block : bool, optional
            If `True`, the pre-activation layers (normalization and activation placed before
            the first convolution) are not created. Defaults to `False`.
            Reference: `Identity Mappings in Deep Residual Networks <https://arxiv.org/pdf/1603.05027.pdf>`_.

        se_block : bool, optional
            Whether to add a Squeeze-and-Excitation (SE) block at the end of the full residual block.
            Defaults to `False`.

        extra_conv : bool, optional
            If `True`, adds an additional convolutional layer with pre-activation before the
            main residual path, as described in Kisuk et al, 2017.
            Reference: `https://arxiv.org/pdf/1706.00120`. Defaults to `False`.
        """
        super(ResConvBlock, self).__init__()
        block = []
        pre_conv = []
        # Full pre-activation: the normalization/activation pair goes either at the start of
        # the residual branch or, with `extra_conv`, in front of the extra convolution.
        if not first_block:
            if not extra_conv:
                if norm != "none":
                    if conv == nn.Conv2d:
                        block.append(get_norm_2d(norm, in_size))
                    else:
                        block.append(get_norm_3d(norm, in_size))
                if act is not None:
                    block.append(get_activation(act))
            else:
                if norm != "none":
                    if conv == nn.Conv2d:
                        pre_conv.append(get_norm_2d(norm, in_size))
                    else:
                        pre_conv.append(get_norm_3d(norm, in_size))
                if act is not None:
                    pre_conv.append(get_activation(act))
        if extra_conv:
            pre_conv.append(
                ConvBlock(
                    conv=conv,
                    in_size=in_size,
                    out_size=out_size,
                    k_size=k_size,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                )
            )
            # From here on the residual branch already sees `out_size` channels.
            in_size = out_size
            self.pre_conv = nn.Sequential(*pre_conv)
        else:
            self.pre_conv = None

        block.append(
            ConvBlock(
                conv=conv,
                in_size=in_size,
                out_size=out_size,
                k_size=k_size,
                act=act,
                norm=norm,
                dropout=dropout,
            )
        )
        # The second convolution relies on the ConvBlock defaults for act, norm and dropout.
        block.append(ConvBlock(conv=conv, in_size=out_size, out_size=out_size, k_size=k_size))
        self.block = nn.Sequential(*block)

        if not extra_conv:
            # Projection shortcut: maps `in_size` channels to `out_size`.
            block = []
            block.append(conv(in_size, out_size, kernel_size=skip_k_size, padding="same"))
            if skip_norm != "none":
                if conv == nn.Conv2d:
                    block.append(get_norm_2d(skip_norm, out_size))
                else:
                    block.append(get_norm_3d(skip_norm, out_size))
            self.shortcut = nn.Sequential(*block)
        else:
            # `pre_conv` already produced `out_size` channels, so no projection is needed.
            self.shortcut = nn.Identity()

        if se_block:
            # add the Squeeze-and-Excitation block at the end of the full block (as in PyTC)
            # (https://github.com/zudi-lin/pytorch_connectomics/blob/master/connectomics/model/block/residual.py#L147-L155)
            self.se_block = SqExBlock(out_size, ndim=2 if conv == nn.Conv2d else 3)
        else:
            self.se_block = nn.Identity()

    def forward(self, x):
        """
        Perform the forward pass through the Residual Convolutional Block.

        The input goes through ``pre_conv`` when it exists; the result is fed to both the
        residual branch and the shortcut, whose outputs are summed and passed to ``se_block``
        (an identity when SE is disabled).

        Parameters
        ----------
        x : torch.Tensor
            The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or
            (batch_size, in_size, H, W).

        Returns
        -------
        torch.Tensor
            The output tensor with `out_size` channels.
        """
        if self.pre_conv is not None:
            x = self.pre_conv(x)
        # NOTE(review): `self.block(x)` is evaluated before `self.shortcut(x)`. When the branch
        # starts directly with an in-place activation (norm == "none", e.g. act == "relu"),
        # `x` appears to be modified before the shortcut reads it - confirm this is intended.
        out = self.block(x) + self.shortcut(x)
        return self.se_block(out)
[docs] class ResUpBlock(nn.Module): """ Implements a Residual Upsampling block, typically used in the decoder path of U-Net-like architectures. This block performs an upsampling operation on the input feature map, concatenates it with a corresponding skip connection (bridge) from the encoder path, and then processes the combined features through a `ResConvBlock`. It supports different upsampling modes and integrates residual connections for improved feature propagation. Parameters ---------- ndim : int Number of dimensions of the input data (2 for 2D, 3 for 3D). convtranspose : Type[nn.ConvTranspose2d | nn.ConvTranspose3d] The transpose convolutional layer type to use if `up_mode` is 'convtranspose'. in_size : int Number of input channels to the upsampling operation (from the previous decoder stage). out_size : int Number of output channels for the final `ResConvBlock` in this upsampling stage. in_size_bridge : int Number of channels of the skip connection (bridge) tensor from the encoder path. z_down : int, optional Downsampling factor applied in the z-dimension for 3D data during upsampling. Only relevant if `ndim` is 3. Defaults to 2. yx_down : int, optional Downsampling factor applied in the y and x dimensions for 2D and 3D data during upsampling. Only relevant if `ndim` is 2 or 3. Defaults to 2. up_mode : str The upsampling mode to use. - 'convtranspose': Uses a transpose convolution (`convtranspose`) for upsampling. - 'upsampling': Uses `nn.Upsample` (bilinear for 2D, trilinear for 3D) followed by a 1x1 convolution. conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use within the internal `ResConvBlock`. k_size : int or tuple Kernel size for the convolutional layers within the `ResConvBlock`. act : str, optional Activation function to use within the `ResConvBlock`. Defaults to `None`. norm : str, optional Normalization layer type to use within the `ResConvBlock`. Options include `'bn'`, `'sync_bn'`, `'in'`, `'gn'`, or `'none'`. 
Defaults to "none". skip_k_size : int or tuple, optional Kernel size for the skip connection convolution within the `ResConvBlock`. Used in ResUNet++. Defaults to 1. skip_norm : str, optional Normalization layer type for the skip connection within the `ResConvBlock`. Defaults to "none". dropout : float, optional Dropout value to be fixed within the `ResConvBlock`. Defaults to 0. se_block : bool, optional Whether to add Squeeze-and-Excitation blocks within the `ResConvBlock`. Defaults to `False`. extra_conv : bool, optional Whether to add an extra convolutional layer before the residual block within the `ResConvBlock` (as in Kisuk et al, 2017). Defaults to `False`. """ def __init__( self, ndim, convtranspose, in_size, out_size, in_size_bridge, z_down, yx_down, up_mode, conv, k_size, act=None, norm="none", skip_k_size: int | tuple[int, ...] = 1, skip_norm="none", dropout=0, se_block=False, extra_conv=False, ): """ Initialize a Residual Upsampling block, typically used in the decoder path of U-Net-like architectures. This block performs an upsampling operation on the input feature map, concatenates it with a corresponding skip connection (bridge) from the encoder path, and then processes the combined features through a `ResConvBlock`. It supports different upsampling modes and integrates residual connections for improved feature propagation. Parameters ---------- ndim : int Number of dimensions of the input data (2 for 2D, 3 for 3D). convtranspose : Type[nn.ConvTranspose2d | nn.ConvTranspose3d] The transpose convolutional layer type to use if `up_mode` is 'convtranspose'. in_size : int Number of input channels to the upsampling operation (from the previous decoder stage). out_size : int Number of output channels for the final `ResConvBlock` in this upsampling stage. in_size_bridge : int Number of channels of the skip connection (bridge) tensor from the encoder path. z_down : int, optional Downsampling factor applied in the z-dimension for 3D data during upsampling. 
Only relevant if `ndim` is 3. Defaults to 2. yx_down : int, optional Downsampling factor applied in the y and x dimensions for 2D and 3D data during upsampling. Only relevant if `ndim` is 2 or 3. Defaults to 2. up_mode : str The upsampling mode to use. - 'convtranspose': Uses a transpose convolution (`convtranspose`) for upsampling. - 'upsampling': Uses `nn.Upsample` (bilinear for 2D, trilinear for 3D) followed by a 1x1 convolution. conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use within the internal `ResConvBlock`. k_size : int or tuple Kernel size for the convolutional layers within the `ResConvBlock`. act : str, optional Activation function to use within the `ResConvBlock`. Defaults to `None`. norm : str, optional Normalization layer type to use within the `ResConvBlock`. Options include `'bn'`, `'sync_bn'`, `'in'`, `'gn'`, or `'none'`. Defaults to "none". skip_k_size : int, optional Kernel size for the skip connection convolution within the `ResConvBlock`. Used in ResUNet++. Defaults to 1. skip_norm : str, optional Normalization layer type for the skip connection within the `ResConvBlock`. Defaults to "none". dropout : float, optional Dropout value to be fixed within the `ResConvBlock`. Defaults to 0. se_block : bool, optional Whether to add Squeeze-and-Excitation blocks within the `ResConvBlock`. Defaults to `False`. extra_conv : bool, optional Whether to add an extra convolutional layer before the residual block within the `ResConvBlock` (as in Kisuk et al, 2017). Defaults to `False`. 
""" super(ResUpBlock, self).__init__() self.ndim = ndim mpool = (z_down, yx_down, yx_down) if ndim == 3 else (yx_down, yx_down) if up_mode == "convtranspose": self.up = convtranspose(in_size, in_size, kernel_size=mpool, stride=mpool) elif up_mode == "upsampling": self.up = nn.Upsample(mode="bilinear" if ndim == 2 else "trilinear", scale_factor=mpool) self.conv_block = ResConvBlock( conv=conv, in_size=in_size + in_size_bridge, out_size=out_size, k_size=k_size, act=act, norm=norm, dropout=dropout, skip_k_size=skip_k_size, skip_norm=skip_norm, se_block=se_block, extra_conv=extra_conv, )
[docs] def forward(self, x, bridge): """ Perform the forward pass of the Residual Upsampling block. First, it upsamples the input tensor `x`. Then, it concatenates the upsampled tensor with the `bridge` tensor (skip connection) along the channel dimension. Finally, the combined tensor is passed through a `ResConvBlock`. Parameters ---------- x : torch.Tensor The input feature tensor from the previous decoder stage. Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W). bridge : torch.Tensor The skip connection tensor from the corresponding encoder stage. Expected shape: (batch_size, in_size_bridge, D', H', W'), where D', H', W' match the spatial dimensions after upsampling `x`. Returns ------- torch.Tensor The output tensor of the upsampling block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' are the upsampled spatial dimensions. """ up = self.up(x) out = torch.cat([up, bridge], 1) out = self.conv_block(out) return out
[docs] class HRBasicBlock(nn.Module): """ Implements a Basic block for High-Resolution Networks (HRNet). This block serves as a fundamental building block in HRNet architectures, designed to maintain high-resolution feature representations throughout the network. It consists of two convolutional layers with a residual connection. Reference: `High-Resolution Representations for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_. Parameters ---------- conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use (e.g., `nn.Conv2d` for 2D, `nn.Conv3d` for 3D). in_size : int Number of input feature channels to the block. out_size : int Number of output feature channels for the convolutional layers within the block. The final output channels of the block will also be `out_size` (since `expansion` is 1). stride : int, optional Stride for the first convolutional layer (`conv1_block`). Defaults to 1. act : Optional[nn.Module], optional Activation layer to apply after the first convolution (`conv1_block`). If `None`, no activation is applied. Defaults to `None`. norm : str, optional Normalization layer type to use within the `ConvBlock` components. Options include `'bn'` (BatchNorm), `'sync_bn'` (SyncBatchNorm), `'in'` (InstanceNorm), `'gn'` (GroupNorm), or `'none'` (no normalization). Defaults to "none". dropout : int, optional Dropout rate to apply within the `ConvBlock` components. If 0, no dropout is applied. Defaults to 0. downsample : Optional[nn.Module], optional An optional downsampling layer to apply to the residual connection if the input `in_size` and `out_size` do not match, or if `stride > 1`. Defaults to `None`. """ expansion = 1 def __init__( self, conv: Type[nn.Conv2d | nn.Conv3d], in_size: int, out_size: int, stride: int = 1, act: Optional[nn.Module] = None, norm: str = "none", dropout: int = 0, downsample: Optional[nn.Module] = None, ): """ Initialize the HRBasicBlock. 
Configures two convolutional layers with optional normalization, activation, and dropout, along with a residual connection. Parameters ---------- conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use. in_size : int Number of input feature channels. out_size : int Number of output feature channels. stride : int, optional Stride for the first convolution. Defaults to 1. act : Optional[nn.Module], optional Activation layer for the first convolution. Defaults to `None`. norm : str, optional Normalization layer type. Defaults to "none". dropout : int, optional Dropout value. Defaults to 0. downsample : Optional[nn.Module], optional Downsample layer for the residual connection. Defaults to `None`. """ super(HRBasicBlock, self).__init__() self.conv1_block = ConvBlock( conv=conv, in_size=in_size, out_size=out_size, k_size=3, padding=1, stride=stride, act=act, norm=norm, dropout=dropout, bias=False, ) self.conv2_block = ConvBlock( conv=conv, in_size=out_size, out_size=out_size, k_size=3, padding=1, stride=1, act=None, norm=norm, dropout=dropout, bias=False, ) self.relu_in = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride
[docs] def forward(self, x): """ Perform the forward pass through the HRBasicBlock. Processes the input through two convolutional layers and adds it to a residual connection. An optional downsampling layer is applied to the residual if necessary. Parameters ---------- x : torch.Tensor The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W). Returns ------- torch.Tensor The output tensor after processing through the basic block and applying the residual connection. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' depend on the stride. """ residual = x out = self.conv1_block(x) out = self.conv2_block(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu_in(out) return out
[docs] class HRBottleneck(nn.Module): """ Implements the Bottleneck block for High-Resolution Networks (HRNet). This block is a building component of HRNet architectures, designed to efficiently process features by reducing and then expanding the channel dimensions, while incorporating a residual connection. It maintains a high-resolution representation throughout the network. Reference: `High-Resolution Representations for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_. Parameters ---------- conv : Type[nn.Conv2d | nn.Conv3d] Convolutional layer to use in the residual block. in_size : int Input feature maps of the convolutional layers. out_size : int Output feature maps of the convolutional layers. stride : int, optional Stride of the convolutional layers. Default is 1. act : Optional[nn.Module], optional Activation layer to use. Default is None, which means no activation layer is applied. norm : str, optional Normalization layer (one of ``'bn'``, ``'sync_bn'`, ``'in'``, ``'gn'`` or ``'none'``). Default is "none". dropout : int, optional Dropout value to be fixed. Default is 0, which means no dropout is applied. downsample : Optional[nn.Module], optional Downsample layer to apply if the input and output sizes do not match. Default is None. """ expansion = 4 def __init__( self, conv: Type[nn.Conv2d | nn.Conv3d], in_size: int, out_size: int, stride: int = 1, act: Optional[nn.Module] = None, norm: str = "none", dropout: int = 0, downsample: Optional[nn.Module] = None, ): """ Initialize the HRBottleneck block. Configures three convolutional layers (1x1, 3x3, 1x1) with optional normalization, activation, and dropout, along with a residual connection. Parameters ---------- conv : Type[nn.Conv2d | nn.Conv3d] The convolutional layer type to use. in_size : int Number of input feature channels. out_size : int Number of output feature channels for the internal convolutions. stride : int, optional Stride for the 3x3 convolution. Defaults to 1. 
act : Optional[nn.Module], optional Activation layer for the final convolution. Defaults to `None`. norm : str, optional Normalization layer type. Defaults to "none". dropout : int, optional Dropout value. Defaults to 0. downsample : Optional[nn.Module], optional Downsample layer for the residual connection. Defaults to `None`. """ super(HRBottleneck, self).__init__() self.conv1_block = ConvBlock( conv=conv, in_size=in_size, out_size=out_size, k_size=1, padding=0, stride=1, act=None, norm=norm, dropout=dropout, bias=False, ) self.conv2_block = ConvBlock( conv=conv, in_size=out_size, out_size=out_size, k_size=3, padding=1, stride=stride, act=None, norm=norm, dropout=dropout, bias=False, ) self.conv3_block = ConvBlock( conv=conv, in_size=out_size, out_size=out_size * 4, k_size=1, padding=0, stride=1, act=act, norm=norm, dropout=dropout, bias=False, ) self.relu_in = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride
[docs] def forward(self, x): """ Perform the forward pass through the HRBottleneck block. Processes the input through a sequence of three convolutional layers and adds it to a residual connection. An optional downsampling layer is applied to the residual if necessary. Parameters ---------- x : torch.Tensor The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W). Returns ------- torch.Tensor The output tensor after processing through the bottleneck block and applying the residual connection. Its shape will be (batch_size, out_size * expansion, D', H', W') or (batch_size, out_size * expansion, H', W'), where D', H', W' depend on the stride. """ residual = x out = self.conv1_block(x) out = self.conv2_block(out) out = self.conv3_block(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu_in(out) return out
def get_activation(activation: str = "relu") -> nn.Module:
    """
    Get the specified activation layer.

    Parameters
    ----------
    activation : str, optional
        One of ``'relu'``, ``'tanh'``, ``'leaky_relu'``, ``'elu'``, ``'gelu'``, ``'silu'``,
        ``'sigmoid'``, ``'softmax'``, ``'linear'``, ``'softplus'`` and ``'none'``.
        ``'linear'`` and ``'none'`` both map to `nn.Identity`. Defaults to ``'relu'``.

    Returns
    -------
    nn.Module
        A newly created activation module. ``'relu'``, ``'leaky_relu'``, ``'elu'`` and
        ``'silu'`` are created with ``inplace=True``; ``'softmax'`` operates over dim 1.

    Raises
    ------
    AssertionError
        If `activation` is not one of the supported keys.
    """
    # Factories instead of instances: only the requested module is actually built.
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "tanh": nn.Tanh,
        "leaky_relu": lambda: nn.LeakyReLU(inplace=True),
        "elu": lambda: nn.ELU(alpha=1.0, inplace=True),
        "gelu": nn.GELU,
        "silu": lambda: nn.SiLU(inplace=True),
        "sigmoid": nn.Sigmoid,
        "softmax": lambda: nn.Softmax(dim=1),
        "linear": nn.Identity,
        "softplus": nn.Softplus,
        "none": nn.Identity,
    }
    if activation not in factories:
        # Raised explicitly (rather than through `assert`) so the check is not stripped
        # under `python -O`, while keeping the exception type of the previous behaviour.
        raise AssertionError("Get unknown activation key {}".format(activation))
    return factories[activation]()
def prepare_activation_layers(
    activations: List[str], output_channel_info: List[str], output_channels: List[int]
) -> Tuple[nn.ModuleList, Optional[nn.ModuleList]]:
    """
    Prepare activation layers for the output and classification heads.

    Every entry of `output_channel_info` is repeated as many times as its number of channels,
    giving one descriptor per output channel. For each channel the matching activation name
    (lower-cased, with a leading ``"ce_"`` removed) is turned into a layer with
    `get_activation`. Channels whose descriptor contains ``"class"`` go to the classification
    list, the rest to the output list. Processing stops right after the first softmax
    activation has been added.

    Parameters
    ----------
    activations : List[str]
        Activation function names, indexed per output channel.

    output_channel_info : List[str]
        A list of strings indicating the type of each output head.

    output_channels : List[int]
        A list of integers indicating the number of channels for each output head.

    Returns
    -------
    out_activations : nn.ModuleList
        A ModuleList containing the activation layers for the output head.

    class_activation : nn.ModuleList or None
        A ModuleList containing the activation layers for the classification head, or None
        if no classification activation was collected.
    """
    # One descriptor per channel: head i contributes `output_channels[i]` copies of its info.
    all_channel_info = [
        c_info for i, c_info in enumerate(output_channel_info) for _ in range(output_channels[i])
    ]

    activation_list = []
    class_activation_list = []
    for i, c_info in enumerate(all_channel_info):
        activation = activations[i].lower().removeprefix("ce_")
        act = get_activation(activation)
        if "class" in c_info:
            class_activation_list.append(act)
        else:
            activation_list.append(act)

        # We break the loop after finding the first softmax activation since we assume that there is only one
        # softmax activation for the classification head (if any)
        if "softmax" in activation:
            break

    if len(class_activation_list) > 0:
        return nn.ModuleList(activation_list), nn.ModuleList(class_activation_list)
    else:
        return nn.ModuleList(activation_list), None
def get_norm_3d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Module:
    """
    Get the specified normalization layer for a 3D model.

    Code adapted from Pytorch for Connectomics:
    `<https://github.com/zudi-lin/pytorch_connectomics/blob/6fbd5457463ae178ecd93b2946212871e9c617ee/connectomics/model/utils/misc.py#L330-L408>`_.

    Parameters
    ----------
    norm : str
        One of ``'bn'``, ``'sync_bn'``, ``'in'``, ``'gn'`` or ``'none'``.

    out_channels : int
        Number of channels to normalize.

    bn_momentum : float, optional
        Momentum passed to the batch, synchronized-batch and instance normalization layers.
        Defaults to 0.1.

    Returns
    -------
    nn.Module
        The normalization layer (`nn.Identity` for ``'none'``).

    Raises
    ------
    AssertionError
        If `norm` is not one of the supported keys.
    """
    import math  # stdlib; local so this function does not rely on the module import block

    if norm not in ("bn", "sync_bn", "gn", "in", "none"):
        # Raised explicitly so the check is kept under `python -O`.
        raise AssertionError("Get unknown normalization layer key {}".format(norm))

    if norm == "bn":
        return nn.BatchNorm3d(out_channels, momentum=bn_momentum)
    if norm == "sync_bn":
        return nn.SyncBatchNorm(out_channels, momentum=bn_momentum)
    if norm == "in":
        return nn.InstanceNorm3d(out_channels, affine=True, momentum=bn_momentum)
    if norm == "gn":
        # nn.GroupNorm takes (num_groups, num_channels): passing the channel count
        # positionally together with `num_groups=` raised a TypeError. 8 groups are used
        # whenever they divide the channel count; otherwise the largest common divisor is
        # taken, since GroupNorm requires num_channels to be divisible by num_groups.
        return nn.GroupNorm(num_groups=math.gcd(8, out_channels), num_channels=out_channels)
    return nn.Identity()
def get_norm_2d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Module:
    """
    Get the specified normalization layer for a 2D model.

    Code adapted from Pytorch for Connectomics:
    `<https://github.com/zudi-lin/pytorch_connectomics/blob/6fbd5457463ae178ecd93b2946212871e9c617ee/connectomics/model/utils/misc.py#L330-L408>`_.

    Parameters
    ----------
    norm : str
        One of ``'bn'``, ``'sync_bn'``, ``'in'``, ``'gn'`` or ``'none'``.

    out_channels : int
        Number of channels to normalize.

    bn_momentum : float, optional
        Momentum passed to the batch, synchronized-batch and instance normalization layers.
        Defaults to 0.1.

    Returns
    -------
    nn.Module
        The normalization layer (`nn.Identity` for ``'none'``).

    Raises
    ------
    AssertionError
        If `norm` is not one of the supported keys.
    """
    import math  # stdlib; local so this function does not rely on the module import block

    if norm not in ("bn", "sync_bn", "gn", "in", "none"):
        # Raised explicitly so the check is kept under `python -O`.
        raise AssertionError("Get unknown normalization layer key {}".format(norm))

    if norm == "bn":
        return nn.BatchNorm2d(out_channels, momentum=bn_momentum)
    if norm == "sync_bn":
        return nn.SyncBatchNorm(out_channels, momentum=bn_momentum)
    if norm == "in":
        return nn.InstanceNorm2d(out_channels, affine=True, momentum=bn_momentum)
    if norm == "gn":
        # nn.GroupNorm takes (num_groups, num_channels): passing the channel count
        # positionally together with `num_groups=` raised a TypeError. 16 groups are used
        # whenever they divide the channel count; otherwise the largest common divisor is
        # taken, since GroupNorm requires num_channels to be divisible by num_groups.
        return nn.GroupNorm(num_groups=math.gcd(16, out_channels), num_channels=out_channels)
    return nn.Identity()
class ResUNetPlusPlus_AttentionBlock(nn.Module):
    """
    Attention block as used in the ResUNet++ architecture.

    Two inputs are projected to `output_dim` channels by independent branches made of an
    optional normalization, a ReLU and a 3x3 convolution. The branch of the first input
    additionally ends with a max-pooling, so that input has to be spatially larger (by the
    pooling factors) than the second one. Both projections are summed and reduced to a
    single-channel map (optional normalization, ReLU, 1x1 convolution) that multiplies the
    second input. No sigmoid is applied to the map inside this block.

    Reference: Adapted from `here <https://github.com/rishikksh20/ResUnet/blob/master/core/modules.py>`_.
    """

    def __init__(
        self,
        conv,
        maxpool,
        input_encoder,
        input_decoder,
        output_dim,
        z_down=2,
        yx_down=2,
        norm="none",
    ):
        """
        Initialize the ResUNetPlusPlus_AttentionBlock.

        Parameters
        ----------
        conv : Type[nn.Conv2d | nn.Conv3d]
            The convolutional layer type to use (e.g., `nn.Conv2d` for 2D, `nn.Conv3d` for 3D).

        maxpool : Type[nn.MaxPool2d | nn.MaxPool3d]
            The max-pooling layer type used at the end of the encoder branch.

        input_encoder : int
            Number of channels of the first input (`x1` in `forward`).

        input_decoder : int
            Number of channels of the second input (`x2` in `forward`).

        output_dim : int
            Number of channels of the intermediate feature maps.

        z_down : int, optional
            Pooling factor for the z-dimension. Only used when `conv` is not `nn.Conv2d`.
            Defaults to 2.

        yx_down : int, optional
            Pooling factor for the y and x dimensions. Defaults to 2.

        norm : str, optional
            Normalization layer type placed at the start of each branch. Options include
            `'bn'`, `'sync_bn'`, `'in'`, `'gn'`, or `'none'` (no normalization).
            Defaults to "none".
        """
        super(ResUNetPlusPlus_AttentionBlock, self).__init__()
        is_2d = conv == nn.Conv2d

        def pre_activation(channels):
            """Return the [normalization, ReLU] prefix shared by the three branches."""
            layers = []
            if norm != "none":
                layers.append(get_norm_2d(norm, channels) if is_2d else get_norm_3d(norm, channels))
            layers.append(nn.ReLU())
            return layers

        pool_kernel = (yx_down, yx_down) if is_2d else (z_down, yx_down, yx_down)

        self.conv_encoder = nn.Sequential(
            *pre_activation(input_encoder),
            conv(input_encoder, output_dim, 3, padding=1),
            maxpool(pool_kernel),
        )
        self.conv_decoder = nn.Sequential(
            *pre_activation(input_decoder),
            conv(input_decoder, output_dim, 3, padding=1),
        )
        self.conv_attn = nn.Sequential(
            *pre_activation(output_dim),
            conv(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        """
        Compute the attention map from ``x1`` and ``x2`` and apply it to ``x2``.

        Parameters
        ----------
        x1 : torch.Tensor
            Input of the encoder branch, with `input_encoder` channels. After the max-pooling
            of that branch its spatial size must equal the one of `x2`.

        x2 : torch.Tensor
            Input of the decoder branch, with `input_decoder` channels.

        Returns
        -------
        torch.Tensor
            ``x2`` multiplied by the single-channel attention map; same shape as ``x2``.
        """
        combined = self.conv_encoder(x1) + self.conv_decoder(x2)
        attention = self.conv_attn(combined)
        return attention * x2
def init_weights(model: nn.Module) -> None:
    """
    Apply Xavier-uniform initialization to the model in place.

    Every ``nn.Conv2d``, ``nn.Conv3d`` and ``nn.Linear`` gets Xavier-uniform weights
    and a zero bias; every ``nn.LayerNorm`` gets weight 1 and bias 0.

    If the model exposes an ``output_channel_info`` sequence containing
    ``"bbox_heatmap"`` and a ``heads`` container, the head found at the position of
    ``"bbox_heatmap"`` is treated as the heatmap head and the bias of its (last)
    convolution is set to -4.59 instead of 0. That value is roughly
    ``-log((1 - p) / p)`` for ``p = 0.01``, i.e. a focal-loss style prior
    (sigmoid(-4.59) is about 0.01, assuming the head output goes through a sigmoid).

    If ``heads`` has no entry at that position, no special bias is applied.

    Parameters
    ----------
    model : nn.Module
        Model whose sub-modules are initialized in place.
    """
    conv_types = (nn.Conv2d, nn.Conv3d)
    heatmap_bias = -4.59

    def _last_conv(module):
        # ``modules()`` yields ``module`` itself first, so a head that already is a
        # convolution resolves to itself; for a container the last conv wins.
        found = None
        for sub in module.modules():
            if isinstance(sub, conv_types):
                found = sub
        return found

    # 1. Locate the convolution that produces the heatmap, if the model has one.
    output_channel_info = getattr(model, "output_channel_info", [""])
    heads = getattr(model, "heads", None)

    hm_conv = None
    if output_channel_info and "bbox_heatmap" in output_channel_info and heads is not None:
        try:
            hm_head = heads[output_channel_info.index("bbox_heatmap")]
        except (IndexError, KeyError, TypeError):
            # ``heads`` does not hold an entry at that position (fewer heads than
            # channels, or not integer-indexable): fall back to plain initialization.
            hm_head = None
        if isinstance(hm_head, nn.Module):
            hm_conv = _last_conv(hm_head)

    # 2. Per-module initialization rule.
    def _apply_init(m):
        if isinstance(m, conv_types):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                # The prior-probability bias goes strictly to the heatmap convolution.
                nn.init.constant_(m.bias, heatmap_bias if m is hm_conv else 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Both are None when the layer was built with elementwise_affine=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    # 3. Visit every sub-module (and the model itself).
    model.apply(_apply_init)