# Source code for biapy.models.unext_v1

"""
This module implements the U-NeXt architecture (Version 1), a U-Net variant that integrates elements from the ConvNeXt model.

It aims to combine the strong hierarchical feature learning of U-Nets with the modern ConvNeXt
design principles, which are inspired by Vision Transformers but retain
the efficiency and inductive biases of convolutional networks.

U-NeXt_V1 is designed for both 2D and 3D image segmentation tasks. It features
a ConvNeXt-style encoder and decoder, with specialized blocks for downsampling,
upsampling, and the bottleneck. It supports various configurations, including
optional super-resolution, multi-head outputs, and stochastic depth for regularization.

Classes:

- ``U_NeXt_V1``: The main U-NeXt model (Version 1).

This module relies on building blocks defined in `biapy.models.blocks`, such as
`UpConvNeXtBlock_V1` and `ConvNeXtBlock_V1`, and on `ProjectionHead` from `biapy.models.heads`.

References:

- `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28>`_
- `A ConvNet for the 2020s <https://openaccess.thecvf.com/content/CVPR2022/html/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.html>`_.

Image representation:

.. image:: ../../img/models/unext.png
    :width: 100%
    :align: center

"""
import torch
import torch.nn as nn
from torchvision.ops.misc import Permute
from typing import Dict, List

from biapy.models.blocks import UpConvNeXtBlock_V1, ConvNeXtBlock_V1, prepare_activation_layers, init_weights
from biapy.models.heads import ProjectionHead


[docs] class U_NeXt_V1(nn.Module): """ Create 2D/3D U-NeXt (Version 1) model. U-NeXt combines the classic U-Net architecture with modern ConvNeXt blocks, aiming to leverage both the strong hierarchical feature learning of U-Nets and the efficiency and inductive biases of ConvNeXt. It is designed for biomedical image segmentation tasks. Reference: `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28>`_, `A ConvNet for the 2020s <https://openaccess.thecvf.com/content/CVPR2022/html/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.html>`_. """ def __init__( self, image_shape=(256, 256, 1), feature_maps=[32, 64, 128, 256], upsample_layer="convtranspose", z_down=[2, 2, 2, 2], yx_down=[2, 2, 2, 2], output_channels=[1], separated_decoders=False, output_channel_info=["F"], explicit_activations: bool = False, head_activations: List[str] = ["ce_sigmoid"], upsampling_factor=(), upsampling_position="pre", stochastic_depth_prob=0.1, layer_scale=1e-6, cn_layers=[2, 2, 2, 2], isotropy=True, stem_k_size=2, contrast: bool = False, contrast_proj_dim: int = 256, return_one_tensor: bool = False, ): """ Initialize the U-NeXt_V1 model. Sets up the ConvNeXt-style encoder (downsampling path), decoder (upsampling path), stem, bottleneck, and optional super-resolution and multi-head output layers. It dynamically selects 2D or 3D convolutional and normalization layers based on `ndim` and `isotropy` settings. Stochastic depth probabilities are progressively increased across layers. Parameters ---------- image_shape : Tuple[int, ...] Dimensions of the input image. E.g., `(y, x, channels)` for 2D or `(z, y, x, channels)` for 3D. The last element `image_shape[-1]` should be the number of input channels. feature_maps : List[int], optional A list specifying the number of feature maps (channels) at each level of the U-NeXt. The length of this list defines the depth of the network. Defaults to `[32, 64, 128, 256]`. 
upsample_layer : str, optional Type of layer to use for upsampling in the decoder path. Two options: "convtranspose" (using `nn.ConvTranspose2d`/`3d`) or "upsampling" (using `nn.Upsample` followed by convolution). Defaults to "convtranspose". z_down : List[int], optional For 3D data, a list of downsampling factors for the z-dimension at each pooling stage in the encoder. Set elements to `1` if the dataset is not isotropic and z-downsampling is not desired at that stage. Its length should match the number of pooling stages (`len(feature_maps) - 1`). Defaults to `[2, 2, 2, 2]`. yx_down : List[int], optional A list of downsampling factors for the y and x dimensions at each pooling stage in the encoder. Its length should match the number of pooling stages (`len(feature_maps) - 1`). Defaults to `[2, 2, 2, 2]`. output_channels : List[int], optional Output channels of the network. If one value is provided, the model will have a single output head. If two values are provided, the model will have two output heads (e.g. for multi-task learning with instance segmentation and classification). separated_decoders : bool, optional Whether to use separated decoders for each output head. output_channel_info : list of str, optional Information about the type of output channels. Possible values are: - "X": where X is a letter, e.g. "F" for foreground, "D" for distance, "R" for rays, "C" for cpntours, etc. - "class": classification (e.g. for multi-task learning) explicit_activations : bool, optional If True, uses explicit activation functions in the last layers. head_activations : List[str], optional Activation functions to apply to each output head if `explicit_activations` is True. upsampling_factor : Tuple[int, ...], optional Factor of upsampling for super-resolution workflows. If provided, it dictates the kernel and stride for an initial or final transposed convolution. Defaults to an empty tuple `()`, meaning no super-resolution. 
upsampling_position : str, optional Determines where super-resolution upsampling is applied: - ``"pre"``: Upsampling is performed *before* the main U-NeXt model. - ``"post"``: Upsampling is performed *after* the main U-NeXt model. Defaults to "pre". stochastic_depth_prob : float, optional Maximum stochastic depth probability. This probability will progressively increase with each layer, reaching its maximum value at the bottleneck layer. Defaults to 0.1. layer_scale : float, optional Layer Scale parameter, used in ConvNeXt blocks. A small positive value (e.g., 1e-6) can stabilize training. Defaults to 1e-6. cn_layers : List[int] Number of ConvNeXt blocks repeated in each level (stage) of the encoder and bottleneck. This list should have the same length as `feature_maps`. Defaults to `[2, 2, 2, 2]`. isotropy : bool or List[bool], optional Controls whether to use 3D or 2D depthwise convolutions at each U-NeXt level when the input is 3D. - If `True` (bool), all levels use 3D depthwise convolutions. - If `False` (bool), all levels use 2D depthwise convolutions (1xKxK kernels for 3D input). - If `List[bool]`, specifies for each level whether to use 3D (True) or 2D (False) kernels. Defaults to True. stem_k_size : int, optional Size of the kernel for the initial stem layer's pooling/convolution. Defaults to 2. contrast : bool, optional Whether to add a contrastive learning projection head to the model. If True, an additional output `embed` will be available in the forward pass. Defaults to `False`. contrast_proj_dim : int, optional Dimension of the projection head for contrastive learning, if `contrast` is True. Defaults to `256`. return_one_tensor : bool, optional If True, concatenates all outputs into a single tensor along the channel dimension in the forward pass. Defaults to `False`. 
""" super(U_NeXt_V1, self).__init__() if len(output_channels) == 0: raise ValueError("'output_channels' needs to has at least one value") if contrast and len(output_channels) > 2: raise ValueError("If 'contrast' is True, 'output_channels' can only have two values at max: one for the main output and one for the class.") print("Selected output channels:") for i, info in enumerate(output_channel_info): print(f" - {i} channel for {info} output") self.depth = len(feature_maps) - 1 self.ndim = 3 if len(image_shape) == 4 else 2 self.z_down = z_down self.yx_down = yx_down self.output_channels = output_channels self.output_channel_info = output_channel_info self.return_class = True if "class" in output_channel_info else False self.explicit_activations = explicit_activations self.return_one_tensor = return_one_tensor if self.explicit_activations: assert len(head_activations) == sum(output_channels), "If 'explicit_activations' is True, 'head_activations' needs to " "have the same number of values as 'output_channels'" self.head_activations, self.class_head_activations = prepare_activation_layers(head_activations, output_channel_info, output_channels) if self.return_class and self.class_head_activations is None: raise ValueError("If 'return_class' is True, 'head_activations' must be provided.") layer_norm = nn.LayerNorm self.contrast = contrast # convert isotropy to list if it is a single bool if type(isotropy) == bool: isotropy = [isotropy] * len(feature_maps) if self.ndim == 3: conv = nn.Conv3d convtranspose = nn.ConvTranspose3d pre_ln_permutation = Permute([0, 2, 3, 4, 1]) post_ln_permutation = Permute([0, 4, 1, 2, 3]) dropout = nn.Dropout3d else: conv = nn.Conv2d convtranspose = nn.ConvTranspose2d pre_ln_permutation = Permute([0, 2, 3, 1]) post_ln_permutation = Permute([0, 3, 1, 2]) dropout = nn.Dropout2d # Super-resolution self.pre_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "pre": self.pre_upsampling = convtranspose( image_shape[-1], 
image_shape[-1], kernel_size=upsampling_factor, stride=upsampling_factor, ) self.down_path = nn.ModuleList() self.downsample_layers = nn.ModuleList() in_channels = image_shape[-1] # STEM z_factor = int(max(z_down[0] / stem_k_size, 1)) mpool = (stem_k_size * z_factor, stem_k_size, stem_k_size) if self.ndim == 3 else (stem_k_size, stem_k_size) self.down_path.append( nn.Sequential( conv(in_channels, feature_maps[0], kernel_size=mpool, stride=mpool), pre_ln_permutation, layer_norm(feature_maps[0]), post_ln_permutation, ) ) # depthwise kernel size for ConvNeXt block kernel_size = (7, 7) if self.ndim == 2 else (7, 7, 7) # Encoder stage_block_id = 0 total_stage_blocks = sum(cn_layers) sd_probs = [] for i in range(self.depth): stage = nn.ModuleList() sd_probs_stage = [] # adjust depthwise kernel size if needed if not isotropy[i] and self.ndim == 3: kernel_size = (1, 7, 7) # ConvNeXtBlocks for _ in range(cn_layers[i]): sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append( ConvNeXtBlock_V1( self.ndim, conv, feature_maps[i], layer_scale, sd_prob, layer_norm, k_size=kernel_size ) ) stage_block_id += 1 sd_probs_stage.append(sd_prob) self.down_path.append(nn.Sequential(*stage)) sd_probs.append(sd_probs_stage) # Downsampling mpool = (z_down[i], yx_down[i], yx_down[i]) if self.ndim == 3 else (yx_down[i], yx_down[i]) self.downsample_layers.append( nn.Sequential( pre_ln_permutation, layer_norm(feature_maps[i]), post_ln_permutation, conv( feature_maps[i], feature_maps[i + 1], kernel_size=mpool, stride=mpool, ), ) ) # BOTTLENECK stage = nn.ModuleList() if not isotropy[-1] and self.ndim == 3: kernel_size = (1, 7, 7) for _ in range(cn_layers[-1]): sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append( ConvNeXtBlock_V1( self.ndim, conv, feature_maps[-1], layer_scale, sd_prob, layer_norm, k_size=kernel_size ) ) stage_block_id += 1 self.bottleneck = nn.Sequential(*stage) # DECODER self.num_decoders = 1 if not 
separated_decoders else len(output_channels) self.up_paths = nn.ModuleList([nn.ModuleList() for _ in range(self.num_decoders)]) for j in range(self.num_decoders): in_channels = feature_maps[-1] for i in range(self.depth - 1, -1, -1): if not isotropy[i] and self.ndim == 3: kernel_size = (1, 7, 7) self.up_paths[j].append( UpConvNeXtBlock_V1( ndim=self.ndim, convtranspose=convtranspose, in_size=in_channels, out_size=feature_maps[i], z_down=z_down[i], yx_down=yx_down[i], up_mode=upsample_layer, conv=conv, attention_gate=False, cn_layers=cn_layers[i], sd_probs=sd_probs[i], layer_scale=layer_scale, layer_norm=layer_norm, k_size=kernel_size, ) ) # type: ignore in_channels = feature_maps[i] # Inverted Stem mpool = (stem_k_size * z_factor, stem_k_size, stem_k_size) if self.ndim == 3 else (stem_k_size, stem_k_size) self.up_paths[j].append( nn.Sequential( convtranspose(feature_maps[0], feature_maps[0], kernel_size=mpool, stride=mpool), pre_ln_permutation, layer_norm(feature_maps[0]), post_ln_permutation, ) ) # type: ignore # Super-resolution self.post_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "post": self.post_upsampling = convtranspose( feature_maps[0], feature_maps[0], kernel_size=upsampling_factor, stride=upsampling_factor, ) if self.contrast: # extra added layers self.heads = nn.Sequential( conv(feature_maps[0], feature_maps[0], kernel_size=3, stride=1, padding=1), layer_norm(feature_maps[0]), dropout(0.10), conv(feature_maps[0], output_channels[0], kernel_size=1, stride=1, padding=0, bias=False), ) self.proj_head = ProjectionHead(ndim=self.ndim, in_channels=feature_maps[0], proj_dim=contrast_proj_dim) else: self.heads = nn.Sequential() for i, out_ch in enumerate(output_channels): self.heads.append(conv(feature_maps[0], out_ch, kernel_size=1, padding="same")) init_weights(self)
[docs] def forward(self, x) -> Dict | torch.Tensor: """ Forward pass of the model. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, channels, height, width) for 2D or (batch_size, channels, depth, height, width) for 3D. Returns ------- Dict or torch.Tensor Model output. Returns a dictionary if multi-head or contrastive outputs are enabled, otherwise returns the main prediction tensor. """ # Super-resolution if self.pre_upsampling: x = self.pre_upsampling(x) # Encoder blocks = [] x = self.down_path[0](x) # (stem) for i, layers in enumerate(zip(self.down_path[1:], self.downsample_layers)): down, pool = layers x = down(x) blocks.append(x) x = pool(x) x_bot = self.bottleneck(x) # Decoder feats = [] for j in range(self.num_decoders): x = x_bot for i, up in enumerate(self.up_paths[j][:-1]): x = up(x, blocks[-i - 1]) x = self.up_paths[j][-1](x) feats.append(x) # Super-resolution if self.post_upsampling: feats[0] = self.post_upsampling(feats[0]) out_dict = {} # Pass the features through the output heads class_outs, outs = [], [] for i, head in enumerate(self.heads): feat = feats[i] if self.num_decoders > 1 else feats[0] if "class" in self.output_channel_info[i]: class_outs.append(head(feat)) else: outs.append(head(feat)) outs = torch.cat(outs, dim=1) # Apply activations to the output heads if explicit_activations is True if self.explicit_activations: # If there is only one activation, apply it to the whole tensor if len(self.head_activations) == 1: outs = self.head_activations[0](outs) else: for i, act in enumerate(self.head_activations): outs[:, i:i+1] = act(outs[:, i:i+1]) if self.return_class and self.class_head_activations is not None: for i, act in enumerate(self.class_head_activations): class_outs[i] = act(class_outs[i]) out_dict = { "pred": outs, } if self.return_class: out_dict["class"] = torch.cat(class_outs, dim=1) # Contrastive learning head if self.contrast: out_dict["embed"] = self.proj_head(feats[0]) if len(out_dict.keys()) == 1: 
return out_dict["pred"] else: if self.return_one_tensor: if "class" in out_dict: return torch.cat((out_dict["pred"], torch.argmax(out_dict["class"], dim=1).unsqueeze(1)), dim=1) else: return out_dict["pred"] return out_dict