biapy.models.blocks
This module contains a collection of fundamental building blocks for convolutional neural networks, primarily designed for biomedical image segmentation architectures like U-Nets and their variants.
It provides modular components for various operations, including:
Convolutional Layers: Basic ConvBlock and DoubleConvBlock for standard feature extraction.
Attention Mechanisms: AttentionBlock to integrate attention gating into skip connections.
Squeeze-and-Excitation Networks: SqExBlock for channel-wise feature recalibration.
ConvNeXt Blocks: Both ConvNeXtBlock_V1 and ConvNeXtBlock_V2 for modern, efficient feature processing with depthwise convolutions, layer normalization, and residual connections.
Global Response Normalization: GRN layer used within ConvNeXt V2 for enhanced feature discrimination.
Upsampling Blocks: UpBlock, UpConvNeXtBlock_V1, and UpConvNeXtBlock_V2 for decoder paths, handling upsampling, skip connection concatenation, and feature refinement.
The blocks are designed to be flexible, supporting both 2D and 3D operations, various normalization types, activations, and configurable parameters like kernel sizes and dropout.
- class biapy.models.blocks.ConvBlock(conv, in_size, out_size, k_size, padding: int | str = 'same', stride=1, bias=True, act=None, norm='none', dropout=0, se_block=False)
Bases: Module
Implements a standard Convolutional Block.
This block consists of a convolutional layer followed by optional normalization, activation, dropout, and a Squeeze-and-Excitation (SE) block. It serves as a versatile building component in various convolutional neural network architectures.
- forward(x)
Perform the forward pass of the Convolutional Block.
Processes the input tensor sequentially through the defined layers: convolution, optional normalization, optional activation, optional dropout, and optional Squeeze-and-Excitation.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, in_size, height, width). Expected shape for 3D: (batch_size, in_size, depth, height, width).
- Returns:
The output tensor after passing through the block. Its shape will be (batch_size, out_size, H', W') or (batch_size, out_size, D', H', W'), where H', W' (and D') depend on padding and stride.
- Return type:
torch.Tensor
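How H', W' (and D') depend on padding and stride can be worked out with the output-size formula that PyTorch documents for its convolution layers. The snippet below is a small pure-Python sketch of that arithmetic for a single axis; the helper name `conv_output_size` is hypothetical and is not part of BiaPy.

```python
def conv_output_size(n, k_size, stride=1, padding=0, dilation=1):
    """Size along one spatial axis after a convolution (floor formula)."""
    return (n + 2 * padding - dilation * (k_size - 1) - 1) // stride + 1


# 3x3 kernel, stride 1, one voxel of padding per side: size is preserved.
same_like = conv_output_size(64, k_size=3, stride=1, padding=1)

# 3x3 kernel, stride 2, one voxel of padding per side: size is halved.
strided = conv_output_size(64, k_size=3, stride=2, padding=1)

# No padding: the output shrinks by k_size - 1.
valid = conv_output_size(64, k_size=3)

print(same_like, strided, valid)
```

Note that PyTorch itself only accepts the string padding 'same' together with a stride of 1, so a strided block needs an explicit integer padding.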
- class biapy.models.blocks.DoubleConvBlock(conv, in_size, out_size, k_size, act=None, norm='none', dropout=0, se_block=False)
Bases: Module
Implements a Double Convolutional Block.
This block consists of two sequential ConvBlock layers. It is a common feature-extraction component in many convolutional neural network architectures, especially U-Net-like models.
- forward(x)
Perform the forward pass of the Double Convolutional Block.
Processes the input tensor sequentially through the two ConvBlock layers.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, in_size, height, width). Expected shape for 3D: (batch_size, in_size, depth, height, width).
- Returns:
The output tensor after passing through both ConvBlock layers. Its shape will be (batch_size, out_size, H', W') or (batch_size, out_size, D', H', W'), where H', W' (and D') match the input spatial dimensions if padding is 'same'.
- Return type:
torch.Tensor
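The `conv` argument used throughout this module is a layer class (for example `nn.Conv2d` or `nn.Conv3d`), which is what lets one block definition serve both 2D and 3D data. The following is a minimal sketch of that pattern in plain PyTorch, not BiaPy's implementation: the helper `double_conv_sketch` and the fixed ReLU activations are assumptions made for illustration.

```python
import torch
import torch.nn as nn


def double_conv_sketch(conv, in_size, out_size, k_size=3):
    """Two conv + ReLU stages; `conv` is the layer class to instantiate."""
    return nn.Sequential(
        conv(in_size, out_size, k_size, padding="same"),
        nn.ReLU(),
        conv(out_size, out_size, k_size, padding="same"),
        nn.ReLU(),
    )


# Same helper, two dimensionalities.
block2d = double_conv_sketch(nn.Conv2d, in_size=1, out_size=16)
block3d = double_conv_sketch(nn.Conv3d, in_size=1, out_size=16)

out2d = block2d(torch.randn(2, 1, 32, 32))     # (B, C, H, W)
out3d = block3d(torch.randn(2, 1, 8, 32, 32))  # (B, C, D, H, W)
print(out2d.shape, out3d.shape)
```

With 'same' padding, only the channel dimension changes, which matches the return shape described above.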
- class biapy.models.blocks.ConvNeXtBlock_V1(ndim, conv, dim, layer_scale=1e-06, stochastic_depth_prob=0.0, layer_norm=None, k_size=7)
Bases: Module
Implements a single ConvNeXt V1 block.
This block is a fundamental building component of ConvNeXt V1 networks, featuring a depthwise convolution, a LayerNorm-Linear-GELU-Linear path, layer scaling, and a stochastic depth residual connection.
- forward(x)
Perform the forward pass of the ConvNeXt V1 block.
Processes the input through a depthwise convolution, layer normalization, and an MLP with GELU. The output of this path is optionally scaled by a learnable layer scale parameter and then added to the original input via a residual connection, applying stochastic depth.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, dim, height, width). Expected shape for 3D: (batch_size, dim, depth, height, width).
- Returns:
The output tensor of the ConvNeXt V1 block, with the same shape as the input.
- Return type:
torch.Tensor
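The forward pass described above can be sketched for the 2D case as follows. This is an illustrative re-implementation under stated assumptions rather than BiaPy's code: the class name `ConvNeXtV1Sketch` is hypothetical, the MLP expansion factor of 4 follows the ConvNeXt paper, and stochastic depth is left out for brevity.

```python
import torch
import torch.nn as nn


class ConvNeXtV1Sketch(nn.Module):
    """Depthwise conv -> LayerNorm -> Linear -> GELU -> Linear -> scale -> residual."""

    def __init__(self, dim, layer_scale=1e-6, k_size=7, expansion=4):
        super().__init__()
        # groups=dim makes the convolution depthwise (one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, k_size, padding=k_size // 2, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, expansion * dim),
            nn.GELU(),
            nn.Linear(expansion * dim, dim),
        )
        # Learnable per-channel layer scale, initialised to a small value.
        self.gamma = nn.Parameter(layer_scale * torch.ones(dim))

    def forward(self, x):
        y = self.dwconv(x)            # (B, C, H, W)
        y = y.permute(0, 2, 3, 1)     # channel-last for LayerNorm and Linear
        y = self.gamma * self.mlp(self.norm(y))
        y = y.permute(0, 3, 1, 2)     # back to channel-first
        return x + y                  # residual connection


x = torch.randn(2, 8, 16, 16)
out = ConvNeXtV1Sketch(dim=8)(x)
print(out.shape)
```

Because the layer scale starts near zero, a freshly initialised block is close to the identity, which is one reason the residual path trains stably.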
- class biapy.models.blocks.GRN(dim)
Bases: Module
Implements the Global Response Normalization (GRN) layer.
This layer enhances feature discrimination by normalizing features based on global responses across channels, as introduced in ConvNeXt V2. It includes learnable parameters for scaling and shifting.
- forward(x)
Perform the forward pass of the GRN layer.
Calculates the L2 norm (Gx) across spatial dimensions for each channel. Then, it normalizes Gx by its mean across channels (Nx). Finally, it applies the learnable gamma and beta parameters to x * Nx and adds the original input x as a residual connection.
- Parameters:
x (torch.Tensor) – The input feature tensor. It is expected in channel-last format (e.g., [B, D, H, W, C] for 3D, or [B, H, W, C] for 2D) so that gamma and beta broadcast over the channel axis. In other words, x must already have been permuted to (B, spatial_dims..., C) before reaching this layer.
- Returns:
The normalized feature tensor, with the same shape as the input x.
- Return type:
torch.Tensor
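The three steps of the GRN forward pass translate almost line by line into tensor operations. The sketch below assumes channel-last input as described above; the class name `GRNSketch`, the `eps` term, and the zero initialisation of gamma and beta (taken from the ConvNeXt V2 reference code) are assumptions, not confirmed details of BiaPy's layer.

```python
import torch
import torch.nn as nn


class GRNSketch(nn.Module):
    """Global Response Normalization for channel-last tensors."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        # All axes between batch (0) and channel (-1) are spatial.
        spatial = tuple(range(1, x.dim() - 1))
        # Gx: L2 norm over the spatial axes, one value per channel.
        gx = x.pow(2).sum(dim=spatial, keepdim=True).sqrt()
        # Nx: each channel's norm relative to the mean norm across channels.
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)
        # Scale and shift the calibrated features, then add the residual.
        return self.gamma * (x * nx) + self.beta + x


grn = GRNSketch(dim=4)
x2d = torch.randn(2, 8, 8, 4)      # (B, H, W, C)
x3d = torch.randn(2, 4, 8, 8, 4)   # (B, D, H, W, C)
out2d, out3d = grn(x2d), grn(x3d)
print(out2d.shape, out3d.shape)
```

With gamma and beta at zero the layer starts as an identity mapping, so it only begins to reweight channels once training moves those parameters.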
- class biapy.models.blocks.ConvNeXtBlock_V2(ndim, conv, dim, stochastic_depth_prob=0.0, layer_norm=None, k_size=7)
Bases: Module
Implements a single ConvNeXt V2 block.
This block is a fundamental building component of ConvNeXt V2 networks, featuring a depthwise convolution, a Permute-LayerNorm-Linear-GELU-GRN-Linear-Permute path, and a stochastic depth residual connection.
- forward(x)
Perform the forward pass of the ConvNeXt V2 block.
Processes the input through a depthwise convolution, layer normalization, and an MLP with GELU and GRN. The output of this path is then added to the original input via a residual connection, optionally applying stochastic depth.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, dim, height, width). Expected shape for 3D: (batch_size, dim, depth, height, width).
- Returns:
The output tensor of the ConvNeXt V2 block, with the same shape as the input.
- Return type:
torch.Tensor
- class biapy.models.blocks.UpBlock(ndim, convtranspose, in_size, out_size, z_down, yx_down, up_mode, conv, k_size, act=None, norm='none', dropout=0, attention_gate=False, se_block=False)
Bases: Module
Implements a standard Upsampling block, commonly used in the decoder path of U-Net-like architectures.
This block performs an upsampling operation, concatenates the upsampled features with a skip connection (bridge) from the encoder, and then processes the combined features through a DoubleConvBlock. It supports different upsampling modes and optional attention gating.
- forward(x, bridge)
Perform the forward pass of the Upsampling block.
First, it upsamples the input tensor x. If an attention gate is enabled, it uses the upsampled x and the bridge tensor to compute attention, then concatenates the upsampled x with the (potentially attended) bridge. Finally, the concatenated tensor is processed by a DoubleConvBlock.
- Parameters:
x (torch.Tensor) – The input feature tensor from the previous decoder stage (lower resolution). Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
bridge (torch.Tensor) – The skip connection tensor from the corresponding encoder stage (higher resolution). Expected shape: (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' match the spatial dimensions after upsampling x.
- Returns:
The output tensor of the upsampling block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), matching the upsampled spatial dimensions and out_size channels.
- Return type:
torch.Tensor
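The shape bookkeeping of this forward pass is easiest to follow with concrete numbers. The sketch below traces a 2D decoder step with a transposed convolution in plain PyTorch. It assumes that the transposed convolution maps in_size to out_size channels and uses a kernel size and stride of 2 (standing in for yx_down); both are illustrative choices, not confirmed details of UpBlock, and the attention gate is omitted.

```python
import torch
import torch.nn as nn

in_size, out_size = 32, 16

# Upsampling: doubles H and W and reduces channels to out_size (assumed).
up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)

# Refinement: a double convolution over the concatenated channels.
refine = nn.Sequential(
    nn.Conv2d(2 * out_size, out_size, 3, padding="same"),
    nn.ReLU(),
    nn.Conv2d(out_size, out_size, 3, padding="same"),
    nn.ReLU(),
)

x = torch.randn(1, in_size, 16, 16)        # lower-resolution decoder input
bridge = torch.randn(1, out_size, 32, 32)  # encoder skip connection

up_x = up(x)                               # (1, 16, 32, 32)
merged = torch.cat([up_x, bridge], dim=1)  # (1, 32, 32, 32)
out = refine(merged)                       # (1, 16, 32, 32)
print(up_x.shape, merged.shape, out.shape)
```

The concatenation only works when the upsampled tensor and the bridge share spatial dimensions, which is why the bridge shape is specified relative to the upsampled x.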
- class biapy.models.blocks.UpConvNeXtBlock_V1(ndim, convtranspose, in_size, out_size, z_down, yx_down, up_mode, conv, attention_gate=False, se_block=False, cn_layers=1, sd_probs=[0.0], layer_scale=1e-06, layer_norm=None, k_size=7)
Bases: Module
Implements an Upsampling block using ConvNeXt V1 components.
This block is designed for the upsampling path of U-Net-like architectures, combining upsampling with concatenation of skip connections, optional attention gating, and a sequence of ConvNeXt V1 blocks for feature refinement.
- forward(x, bridge)
Perform the forward pass of the UpConvNeXtBlock_V1.
First, it upsamples the input tensor x. If an attention gate is enabled, it uses the upsampled x and the bridge tensor to compute attention, then concatenates the upsampled x with the (potentially attended) bridge. Finally, the concatenated tensor is processed by an initial convolutional block and then refined through a sequence of ConvNeXt V1 blocks.
- Parameters:
x (torch.Tensor) – The input feature tensor from the previous decoder stage (lower resolution). Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
bridge (torch.Tensor) – The skip connection tensor from the corresponding encoder stage (higher resolution). Expected shape: (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' match the spatial dimensions after upsampling x.
- Returns:
The output tensor of the upsampling block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), matching the upsampled spatial dimensions and out_size channels.
- Return type:
torch.Tensor
- class biapy.models.blocks.UpConvNeXtBlock_V2(ndim, convtranspose, in_size, out_size, z_down, yx_down, up_mode, conv, attention_gate=False, se_block=False, cn_layers=1, sd_probs=[0.0], layer_norm=None, k_size=7)
Bases: Module
Implements an Upsampling block using ConvNeXt V2 components.
This block is designed for the upsampling path of U-Net-like architectures, combining upsampling with concatenation of skip connections, optional attention gating, and a sequence of ConvNeXt V2 blocks for feature refinement.
- forward(x, bridge)
Perform the forward pass of the UpConvNeXtBlock_V2.
First, it upsamples the input tensor x. If an attention gate is enabled, it uses the upsampled x and the bridge tensor to compute attention, then concatenates the upsampled x with the (potentially attended) bridge. Finally, the concatenated tensor is processed by an initial convolutional block and then refined through a sequence of ConvNeXt V2 blocks.
- Parameters:
x (torch.Tensor) – The input feature tensor from the previous decoder stage (lower resolution). Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
bridge (torch.Tensor) – The skip connection tensor from the corresponding encoder stage (higher resolution). Expected shape: (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' match the spatial dimensions after upsampling x.
- Returns:
The output tensor of the upsampling block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), matching the upsampled spatial dimensions and out_size channels.
- Return type:
torch.Tensor
- class biapy.models.blocks.AttentionBlock(conv, in_size, out_size, norm='none')
Bases: Module
Implements an Attention Block, as proposed in Attention U-Net.
This block refines skip connections in U-Net-like architectures by generating attention coefficients. It learns to focus on salient features from the skip pathway (x) guided by the features from the coarser, upsampled pathway (g).
Reference: Attention U-Net: Learning Where to Look for the Pancreas.
- forward(g, x)
Perform the forward pass of the Attention Block.
Processes the gating signal g and the skip connection x independently through convolutional layers. Their outputs are summed, passed through a ReLU activation, and then a 1x1 convolution with Sigmoid activation generates the attention coefficients (psi). Finally, these attention coefficients are multiplied element-wise with the skip connection x to produce the attention-gated output.
- Parameters:
g (torch.Tensor) – The gating signal tensor from the coarser (upsampled) pathway. Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
x (torch.Tensor) – The skip connection tensor from the corresponding encoder pathway. Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W). Spatial dimensions must match those of g after any necessary upsampling of g.
- Returns:
The attention-gated feature tensor, with the same shape as x.
- Return type:
torch.Tensor
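The gating computation described for forward(g, x) can be sketched in a few lines of PyTorch. This is a simplified 2D illustration, not BiaPy's AttentionBlock: the class name `AttentionGateSketch`, the use of 1x1 convolutions for both branches, and the absence of normalization are assumptions.

```python
import torch
import torch.nn as nn


class AttentionGateSketch(nn.Module):
    """Additive attention gate: psi = sigmoid(conv(relu(W_g g + W_x x)))."""

    def __init__(self, in_size, out_size):
        super().__init__()
        self.w_g = nn.Conv2d(in_size, out_size, kernel_size=1)
        self.w_x = nn.Conv2d(in_size, out_size, kernel_size=1)
        # A single-channel map so every spatial position gets one coefficient.
        self.psi = nn.Sequential(nn.Conv2d(out_size, 1, kernel_size=1), nn.Sigmoid())
        self.relu = nn.ReLU()

    def forward(self, g, x):
        a = self.relu(self.w_g(g) + self.w_x(x))  # joint features of g and x
        psi = self.psi(a)                         # (B, 1, H, W), values in (0, 1)
        return x * psi                            # broadcast over channels


g = torch.randn(1, 16, 32, 32)  # gating signal (already upsampled)
x = torch.randn(1, 16, 32, 32)  # skip connection
out = AttentionGateSketch(in_size=16, out_size=8)(g, x)
print(out.shape)
```

Since the coefficients lie between 0 and 1, the gate can only attenuate the skip features; it never amplifies them.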
- class biapy.models.blocks.SqExBlock(c, r=16, ndim=2)
Bases: Module
Implements the Squeeze-and-Excitation (SE) block, a computational unit that adaptively recalibrates channel-wise feature responses.
This block enhances the representational power of a network by explicitly modeling interdependencies between channels, allowing the network to perform feature recalibration.
Reference: Squeeze and Excitation Networks. Credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4
- forward(x)
Perform the forward pass of the Squeeze-and-Excitation block.
Applies a global average pooling (squeeze) to the input to aggregate spatial information into a channel descriptor. This descriptor is then passed through a fully connected excitation network to predict channel-wise attention weights. Finally, these weights are applied to the input feature map by channel-wise multiplication.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, channels, height, width). Expected shape for 3D: (batch_size, channels, depth, height, width).
- Returns:
The recalibrated feature tensor, with the same shape as the input x.
- Return type:
torch.Tensor
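The squeeze, excitation, and rescaling steps map directly onto standard PyTorch layers. The following 2D sketch follows the structure of the credited senet.pytorch module; the class name `SESketch` is hypothetical, and the parameter names c and r mirror the signature above (channels and reduction ratio).

```python
import torch
import torch.nn as nn


class SESketch(nn.Module):
    """Squeeze-and-Excitation for 2D feature maps."""

    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)   # global average pooling
        self.excite = nn.Sequential(
            nn.Linear(c, c // r, bias=False),    # bottleneck by ratio r
            nn.ReLU(),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid(),                        # per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c = x.shape[:2]
        w = self.excite(self.squeeze(x).flatten(1))  # (B, C)
        return x * w.view(b, c, 1, 1)                # channel-wise rescaling


x = torch.randn(2, 32, 16, 16)
out = SESketch(c=32, r=16)(x)
print(out.shape)
```

The channel count must be at least r for the bottleneck to keep one or more hidden units; with c=32 and r=16 the hidden layer has two units.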
- class biapy.models.blocks.ResConvBlock(conv, in_size, out_size, k_size, act=None, norm='none', dropout=0, skip_k_size: int | Tuple[int, ...] = 1, skip_norm='none', first_block=False, se_block=False, extra_conv=False)
Bases: Module
Implements a Residual Convolutional Block.
This block is a core component often used in U-Net-like architectures to build encoder and decoder paths. It consists of a sequence of convolutional layers with a skip connection, allowing for better gradient flow and feature reuse. It supports optional pre-activation, Squeeze-and-Excitation blocks, and an initial extra convolutional layer.
- forward(x)
Perform the forward pass through the Residual Convolutional Block.
Processes the input tensor through an optional pre-convolutional layer, then through the main convolutional blocks, and finally adds a skip connection. An optional Squeeze-and-Excitation block is applied at the end.
- Parameters:
x (torch.Tensor) – The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
- Returns:
The output tensor after processing through the residual block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' match the input spatial dimensions if padding='same' is used.
- Return type:
torch.Tensor
- class biapy.models.blocks.ResUpBlock(ndim, convtranspose, in_size, out_size, in_size_bridge, z_down, yx_down, up_mode, conv, k_size, act=None, norm='none', skip_k_size: int | tuple[int, ...] = 1, skip_norm='none', dropout=0, se_block=False, extra_conv=False)
Bases: Module
Implements a Residual Upsampling block, typically used in the decoder path of U-Net-like architectures.
This block performs an upsampling operation on the input feature map, concatenates it with a corresponding skip connection (bridge) from the encoder path, and then processes the combined features through a ResConvBlock. It supports different upsampling modes and integrates residual connections for improved feature propagation.
- Parameters:
ndim (int) – Number of dimensions of the input data (2 for 2D, 3 for 3D).
convtranspose (Type[nn.ConvTranspose2d | nn.ConvTranspose3d]) – The transpose convolutional layer type to use if up_mode is 'convtranspose'.
in_size (int) – Number of input channels to the upsampling operation (from the previous decoder stage).
out_size (int) – Number of output channels for the final ResConvBlock in this upsampling stage.
in_size_bridge (int) – Number of channels of the skip connection (bridge) tensor from the encoder path.
z_down (int) – Downsampling factor applied in the z-dimension for 3D data during upsampling. Only relevant if ndim is 3.
yx_down (int) – Downsampling factor applied in the y and x dimensions for 2D and 3D data during upsampling.
up_mode (str) – The upsampling mode to use.
'convtranspose': Uses a transpose convolution (convtranspose) for upsampling.
'upsampling': Uses nn.Upsample (bilinear for 2D, trilinear for 3D) followed by a 1x1 convolution.
conv (Type[nn.Conv2d | nn.Conv3d]) – The convolutional layer type to use within the internal ResConvBlock.
k_size (int or tuple) – Kernel size for the convolutional layers within the ResConvBlock.
act (str, optional) – Activation function to use within the ResConvBlock. Defaults to None.
norm (str, optional) – Normalization layer type to use within the ResConvBlock. Options include 'bn', 'sync_bn', 'in', 'gn', or 'none'. Defaults to 'none'.
skip_k_size (int or tuple, optional) – Kernel size for the skip connection convolution within the ResConvBlock. Used in ResUNet++. Defaults to 1.
skip_norm (str, optional) – Normalization layer type for the skip connection within the ResConvBlock. Defaults to 'none'.
dropout (float, optional) – Dropout value to be fixed within the ResConvBlock. Defaults to 0.
se_block (bool, optional) – Whether to add Squeeze-and-Excitation blocks within the ResConvBlock. Defaults to False.
extra_conv (bool, optional) – Whether to add an extra convolutional layer before the residual block within the ResConvBlock (as in Kisuk et al., 2017). Defaults to False.
- forward(x, bridge)
Perform the forward pass of the Residual Upsampling block.
First, it upsamples the input tensor x. Then, it concatenates the upsampled tensor with the bridge tensor (skip connection) along the channel dimension. Finally, the combined tensor is passed through a ResConvBlock.
- Parameters:
x (torch.Tensor) – The input feature tensor from the previous decoder stage. Expected shape: (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
bridge (torch.Tensor) – The skip connection tensor from the corresponding encoder stage. Expected shape: (batch_size, in_size_bridge, D', H', W') or (batch_size, in_size_bridge, H', W'), where D', H', W' match the spatial dimensions after upsampling x.
- Returns:
The output tensor of the upsampling block. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' are the upsampled spatial dimensions.
- Return type:
torch.Tensor
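For the 'upsampling' mode, the sequence of interpolation, 1x1 convolution, and concatenation with a bridge that has its own channel count can be traced as follows. This is a rough 2D sketch in plain PyTorch: it assumes the 1x1 convolution keeps in_size channels, and a single convolution stands in for the ResConvBlock. Neither detail is confirmed by the documentation above.

```python
import torch
import torch.nn as nn

in_size, in_size_bridge, out_size = 32, 24, 16

# 'upsampling' mode: bilinear interpolation followed by a 1x1 convolution.
up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(in_size, in_size, kernel_size=1),  # channel count kept (assumed)
)

# Stand-in for the ResConvBlock: maps the concatenated channels to out_size.
refine = nn.Conv2d(in_size + in_size_bridge, out_size, 3, padding="same")

x = torch.randn(1, in_size, 16, 16)
bridge = torch.randn(1, in_size_bridge, 32, 32)

up_x = up(x)                               # (1, 32, 32, 32)
merged = torch.cat([up_x, bridge], dim=1)  # (1, 56, 32, 32)
out = refine(merged)                       # (1, 16, 32, 32)
print(up_x.shape, merged.shape, out.shape)
```

Passing in_size_bridge separately is what allows the bridge to carry a different number of channels than the upsampled tensor.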
- class biapy.models.blocks.HRBasicBlock(conv: Type[Conv2d | Conv3d], in_size: int, out_size: int, stride: int = 1, act: Module | None = None, norm: str = 'none', dropout: int = 0, downsample: Module | None = None)
Bases: Module
Implements a Basic block for High-Resolution Networks (HRNet).
This block serves as a fundamental building block in HRNet architectures, designed to maintain high-resolution feature representations throughout the network. It consists of two convolutional layers with a residual connection.
- Parameters:
conv (Type[nn.Conv2d | nn.Conv3d]) – The convolutional layer type to use (e.g., nn.Conv2d for 2D, nn.Conv3d for 3D).
in_size (int) – Number of input feature channels to the block.
out_size (int) – Number of output feature channels for the convolutional layers within the block. The final output channels of the block will also be out_size (since expansion is 1).
stride (int, optional) – Stride for the first convolutional layer (conv1_block). Defaults to 1.
act (Optional[nn.Module], optional) – Activation layer to apply after the first convolution (conv1_block). If None, no activation is applied. Defaults to None.
norm (str, optional) – Normalization layer type to use within the ConvBlock components. Options include 'bn' (BatchNorm), 'sync_bn' (SyncBatchNorm), 'in' (InstanceNorm), 'gn' (GroupNorm), or 'none' (no normalization). Defaults to 'none'.
dropout (int, optional) – Dropout rate to apply within the ConvBlock components. If 0, no dropout is applied. Defaults to 0.
downsample (Optional[nn.Module], optional) – An optional downsampling layer to apply to the residual connection if the input in_size and out_size do not match, or if stride > 1. Defaults to None.
- expansion = 1
- forward(x)
Perform the forward pass through the HRBasicBlock.
Processes the input through two convolutional layers and adds it to a residual connection. An optional downsampling layer is applied to the residual if necessary.
- Parameters:
x (torch.Tensor) – The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
- Returns:
The output tensor after processing through the basic block and applying the residual connection. Its shape will be (batch_size, out_size, D', H', W') or (batch_size, out_size, H', W'), where D', H', W' depend on the stride.
- Return type:
torch.Tensor
- class biapy.models.blocks.HRBottleneck(conv: Type[Conv2d | Conv3d], in_size: int, out_size: int, stride: int = 1, act: Module | None = None, norm: str = 'none', dropout: int = 0, downsample: Module | None = None)
Bases: Module
Implements the Bottleneck block for High-Resolution Networks (HRNet).
This block is a building component of HRNet architectures, designed to efficiently process features by reducing and then expanding the channel dimensions, while incorporating a residual connection. It maintains a high-resolution representation throughout the network.
- Parameters:
conv (Type[nn.Conv2d | nn.Conv3d]) – Convolutional layer to use in the residual block.
in_size (int) – Input feature maps of the convolutional layers.
out_size (int) – Output feature maps of the convolutional layers.
stride (int, optional) – Stride of the convolutional layers. Default is 1.
act (Optional[nn.Module], optional) – Activation layer to use. Default is None, which means no activation layer is applied.
norm (str, optional) – Normalization layer (one of 'bn', 'sync_bn', 'in', 'gn' or 'none'). Default is 'none'.
dropout (int, optional) – Dropout value to be fixed. Default is 0, which means no dropout is applied.
downsample (Optional[nn.Module], optional) – Downsample layer to apply if the input and output sizes do not match. Default is None.
- expansion = 4
- forward(x)
Perform the forward pass through the HRBottleneck block.
Processes the input through a sequence of three convolutional layers and adds it to a residual connection. An optional downsampling layer is applied to the residual if necessary.
- Parameters:
x (torch.Tensor) – The input feature tensor. Its shape should be (batch_size, in_size, D, H, W) or (batch_size, in_size, H, W).
- Returns:
The output tensor after processing through the bottleneck block and applying the residual connection. Its shape will be (batch_size, out_size * expansion, D', H', W') or (batch_size, out_size * expansion, H', W'), where D', H', W' depend on the stride.
- Return type:
torch.Tensor
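The role of expansion and downsample is easiest to see in a small example. The sketch below is a generic 2D bottleneck written in plain PyTorch, not BiaPy's HRBottleneck: the 1x1, 3x3, 1x1 kernel layout is the usual ResNet convention and is assumed here, as are the fixed ReLU activations and the name `BottleneckSketch`.

```python
import torch
import torch.nn as nn


class BottleneckSketch(nn.Module):
    expansion = 4  # output channels = out_size * expansion

    def __init__(self, in_size, out_size, stride=1, downsample=None):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_size, out_size, 1),                            # reduce
            nn.ReLU(),
            nn.Conv2d(out_size, out_size, 3, stride=stride, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_size, out_size * self.expansion, 1),          # expand
        )
        self.downsample = downsample
        self.act = nn.ReLU()

    def forward(self, x):
        # The residual must match the body output in channels and size.
        residual = x if self.downsample is None else self.downsample(x)
        return self.act(self.body(x) + residual)


# Case 1: in_size == out_size * expansion and stride 1, no projection needed.
out_plain = BottleneckSketch(64, 16)(torch.randn(1, 64, 32, 32))

# Case 2: channels and resolution change, so the residual is projected.
proj = nn.Conv2d(32, 16 * BottleneckSketch.expansion, kernel_size=1, stride=2)
out_proj = BottleneckSketch(32, 16, stride=2, downsample=proj)(torch.randn(1, 32, 32, 32))
print(out_plain.shape, out_proj.shape)
```

In both cases the output has out_size * expansion = 64 channels; only the second case halves the spatial size.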
- biapy.models.blocks.get_activation(activation: str = 'relu') -> Module
Get the specified activation layer.
- Parameters:
activation (str, optional) – One of 'relu', 'tanh', 'leaky_relu', 'elu', 'gelu', 'silu', 'sigmoid', 'softmax', 'swish', 'efficient_swish', 'linear', 'softplus' and 'none'.
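A lookup of this kind is typically a mapping from names to layer classes. The sketch below shows the pattern for a subset of the names listed above; it is not the BiaPy implementation. The function name `get_activation_sketch` is hypothetical, and treating 'linear' and 'none' as identity layers is an assumption.

```python
import torch.nn as nn


def get_activation_sketch(activation="relu"):
    """Return a fresh activation layer for a (subset of) supported names."""
    table = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "sigmoid": nn.Sigmoid,
        "softplus": nn.Softplus,
        "linear": nn.Identity,  # assumed to mean "no non-linearity"
        "none": nn.Identity,    # assumed to mean "no activation"
    }
    key = activation.lower()
    if key not in table:
        raise ValueError(f"Unknown activation: {activation!r}")
    return table[key]()  # instantiate so each caller gets its own module


act = get_activation_sketch("gelu")
print(act)
```

Returning a new instance on every call avoids accidentally sharing one module object between unrelated layers.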
- biapy.models.blocks.prepare_activation_layers(activations: List[str], output_channel_info: List[str], output_channels: List[int]) -> Tuple[ModuleList, ModuleList | None]
Prepare activation layers for the output and classification heads.
- Parameters:
activations (List[str]) – A list of activation function names.
output_channel_info (List[str]) – A list of strings indicating the type of output channels.
output_channels (List[int]) – A list of integers indicating the number of channels for each output head.
- Returns:
out_activations (nn.ModuleList) – A ModuleList containing the activation layers for the output head.
class_activation (nn.ModuleList or None) – A ModuleList containing the activation layers for the classification head, or None if not provided.
- biapy.models.blocks.get_norm_3d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> Module
Get the specified normalization layer for a 3D model.
Code adapted from Pytorch for Connectomics.
- Args:
norm (str): one of 'bn', 'sync_bn', 'in', 'gn' or 'none'.
out_channels (int): channel number.
bn_momentum (float): the momentum of normalization layers.
- Returns:
nn.Module: the normalization layer
- biapy.models.blocks.get_norm_2d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> Module
Get the specified normalization layer for a 2D model.
Code adapted from Pytorch for Connectomics.
- Args:
norm (str): one of 'bn', 'sync_bn', 'in', 'gn' or 'none'.
out_channels (int): channel number.
bn_momentum (float): the momentum of normalization layers.
- Returns:
nn.Module: the normalization layer
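As a rough illustration of how the five norm names could map to PyTorch layers in the 2D case, consider the sketch below. It is not the BiaPy or PyTorch Connectomics code: the function name `get_norm_2d_sketch`, the `groups` parameter for GroupNorm, and the affine InstanceNorm are all assumptions.

```python
import torch
import torch.nn as nn


def get_norm_2d_sketch(norm, out_channels, bn_momentum=0.1, groups=8):
    """Map a norm name to a 2D normalization layer (illustrative only)."""
    if norm == "bn":
        return nn.BatchNorm2d(out_channels, momentum=bn_momentum)
    if norm == "sync_bn":
        return nn.SyncBatchNorm(out_channels, momentum=bn_momentum)
    if norm == "in":
        return nn.InstanceNorm2d(out_channels, affine=True)
    if norm == "gn":
        # GroupNorm requires out_channels to be divisible by the group count.
        return nn.GroupNorm(groups, out_channels)
    if norm == "none":
        return nn.Identity()
    raise ValueError(f"Unknown normalization: {norm!r}")


x = torch.randn(2, 16, 8, 8)
outputs = {name: get_norm_2d_sketch(name, 16)(x) for name in ("bn", "in", "gn", "none")}
print({name: tuple(t.shape) for name, t in outputs.items()})
```

A 3D variant would follow the same structure with nn.BatchNorm3d and nn.InstanceNorm3d, while GroupNorm and SyncBatchNorm are dimension-agnostic.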
- class biapy.models.blocks.ResUNetPlusPlus_AttentionBlock(conv, maxpool, input_encoder, input_decoder, output_dim, z_down=2, yx_down=2, norm='none')
Bases: Module
Implements an attention block as used in the ResUNet++ architecture.
This block is designed to refine skip connections in a U-Net-like architecture by selectively emphasizing relevant features. It combines information from both the encoder (downsampling path) and decoder (upsampling path) to generate an attention map, which is then applied to the decoder's input.
- Reference:
Adapted from here.
- Parameters:
conv (Type[nn.Conv2d | nn.Conv3d]) – The convolutional layer type to use (e.g., nn.Conv2d for 2D, nn.Conv3d for 3D).
maxpool (Type[nn.MaxPool2d | nn.MaxPool3d]) – The max-pooling layer type to use.
input_encoder (int) – Number of input channels from the encoder path (larger feature map).
input_decoder (int) – Number of input channels from the decoder path (upsampled feature map, skip connection).
output_dim (int) – The desired number of output channels for the internal convolutional layers within the attention block.
z_down (int, optional) – Downsampling factor for the z-dimension (depth) in 3D max-pooling. Only relevant if conv is nn.Conv3d. Defaults to 2.
yx_down (int, optional) – Downsampling factor for the y and x dimensions in 2D and 3D max-pooling. Defaults to 2.
norm (str, optional) – Normalization layer type to use within the convolutional sub-blocks. Options include 'bn' (BatchNorm), 'sync_bn' (SyncBatchNorm), 'in' (InstanceNorm), 'gn' (GroupNorm), or 'none' (no normalization). Defaults to 'none'.
- forward(x1, x2)
Perform the forward pass of the attention block.
It processes inputs from the encoder (x1) and decoder (x2), sums them, applies an attention convolution, and then multiplies the resulting attention map with the decoder input (x2).
- Parameters:
x1 (torch.Tensor) – The input tensor from the encoder path (downsampled feature map). Expected shape: (batch_size, input_encoder, D, H, W) or (batch_size, input_encoder, H, W).
x2 (torch.Tensor) – The input tensor from the decoder path (upsampled feature map, skip connection). Expected shape: (batch_size, input_decoder, D, H, W) or (batch_size, input_decoder, H, W).
- Returns:
The attended output tensor, with the same shape as x2.
- Return type:
torch.Tensor