Source code for biapy.models.unext_v2

"""
This module implements the U-NeXt (Version 2) architecture, a U-Net based model that incorporates the latest advancements from ConvNeXt V2 blocks.

It aims to combine the strong hierarchical feature learning of U-Nets with the improved
design principles of ConvNeXt V2, which are co-designed and scaled with Masked
Autoencoders for enhanced performance.

U-NeXt_V2 is designed for both 2D and 3D image segmentation tasks. It features
a ConvNeXt V2-style encoder and decoder, with specialized blocks for downsampling,
upsampling, and the bottleneck. It supports various configurations, including
optional super-resolution, multi-head outputs, and stochastic depth for regularization.

Classes:

- ``U_NeXt_V2``: The main U-NeXt model (Version 2).

This module relies on building blocks defined in `biapy.models.blocks`, such as
`UpConvNeXtBlock_V2`, `ConvNeXtBlock_V2`, and `ProjectionHead`.

References:

- `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28>`_
- `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders <https://openaccess.thecvf.com/content/CVPR2023/html/Woo_ConvNeXt_V2_Co-Designing_and_Scaling_ConvNets_With_Masked_Autoencoders_CVPR_2023_paper.html>`_.

Image representation:

.. image:: ../../img/models/unext.png
    :width: 100%
    :align: center

"""

import torch
import torch.nn as nn
from typing import Dict, List

from biapy.models.blocks import UpConvNeXtBlock_V2, ConvNeXtBlock_V2, prepare_activation_layers, init_weights
from torchvision.ops.misc import Permute
from biapy.models.heads import ProjectionHead


[docs] class U_NeXt_V2(nn.Module): """ Create 2D/3D U-NeXt V2 (U-Net based model with ConvNeXt V2 blocks). U-NeXt V2 combines the classic U-Net architecture with modern ConvNeXt V2 blocks, leveraging the co-design and scaling principles from Masked Autoencoders. This model aims to achieve high performance in biomedical image segmentation by integrating strong hierarchical feature learning with efficient and robust convolutional designs. Reference: `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28>`_, `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders <https://openaccess.thecvf.com/content/CVPR2023/html/Woo_ConvNeXt_V2_Co-Designing_and_Scaling_ConvNets_With_Masked_Autoencoders_CVPR_2023_paper.html>`_. """ def __init__( self, image_shape=(256, 256, 1), feature_maps=[32, 64, 128, 256], upsample_layer="convtranspose", z_down=[2, 2, 2, 2], yx_down=[2, 2, 2, 2], output_channels=[1], separated_decoders=False, output_channel_info=["F"], explicit_activations: bool = False, head_activations: List[str] = ["ce_sigmoid"], upsampling_factor=(), upsampling_position="pre", stochastic_depth_prob=0.1, cn_layers=[2, 2, 2, 2], isotropy=True, stem_k_size=2, contrast: bool = False, contrast_proj_dim: int = 256, return_one_tensor: bool = False, ): """ Initialize the U-NeXt_V2 model. Sets up the ConvNeXt V2-style encoder (downsampling path), decoder (upsampling path), stem, bottleneck, and optional super-resolution and multi-head output layers. It dynamically selects 2D or 3D convolutional and normalization layers based on `ndim` and `isotropy` settings. Stochastic depth probabilities are progressively increased across layers. Parameters ---------- image_shape : Tuple[int, ...] Dimensions of the input image. E.g., `(y, x, channels)` for 2D or `(z, y, x, channels)` for 3D. The last element `image_shape[-1]` should be the number of input channels. activation : str, optional Activation layer to be used throughout the model. (Note: ConvNeXt V2 blocks typically use GELU, this parameter might be less relevant for internal block activations but could apply to other parts if customized). feature_maps : List[int], optional A list specifying the number of feature maps (channels) at each level of the U-NeXt. The length of this list defines the depth of the network. Defaults to `[32, 64, 128, 256]`. upsample_layer : str, optional Type of layer to use for upsampling in the decoder path. Two options: "convtranspose" (using `nn.ConvTranspose2d`/`3d`) or "upsampling" (using `nn.Upsample` followed by convolution). Defaults to "convtranspose". z_down : List[int], optional For 3D data, a list of downsampling factors for the z-dimension at each pooling stage in the encoder. Set elements to `1` if the dataset is not isotropic and z-downsampling is not desired at that stage. Its length should match the number of pooling stages (`len(feature_maps) - 1`). Defaults to `[2, 2, 2, 2]`. yx_down : List[int], optional A list of downsampling factors for the y and x dimensions at each pooling stage in the encoder. Its length should match the number of pooling stages (`len(feature_maps) - 1`). Defaults to `[2, 2, 2, 2]`. output_channels : List[int], optional Output channels of the network. If one value is provided, the model will have a single output head. If two values are provided, the model will have two output heads (e.g. for multi-task learning with instance segmentation and classification). separated_decoders : bool, optional Whether to use separated decoders for each output head. output_channel_info : list of str, optional Information about the type of output channels. Possible values are: - "X": where X is a letter, e.g. "F" for foreground, "D" for distance, "R" for rays, "C" for cpntours, etc. - "class": classification (e.g. for multi-task learning) explicit_activations : bool, optional If True, uses explicit activation functions in the last layers. head_activations : List[str], optional Activation functions to apply to each output head if `explicit_activations` is True. upsampling_factor : Tuple[int, ...], optional Factor of upsampling for super-resolution workflows. If provided, it dictates the kernel and stride for an initial or final transposed convolution. Defaults to an empty tuple `()`, meaning no super-resolution. upsampling_position : str, optional Determines where super-resolution upsampling is applied: - ``"pre"``: Upsampling is performed *before* the main U-NeXt model. - ``"post"``: Upsampling is performed *after* the main U-NeXt model. Defaults to "pre". stochastic_depth_prob : float, optional Maximum stochastic depth probability. This probability will progressively increase with each layer, reaching its maximum value at the bottleneck layer. Defaults to 0.1. cn_layers : List[int] Number of ConvNeXt V2 blocks repeated in each level (stage) of the encoder and bottleneck. This list should have the same length as `feature_maps`. Defaults to `[2, 2, 2, 2]`. isotropy : bool or List[bool], optional Controls whether to use 3D or 2D depthwise convolutions at each U-NeXt level when the input is 3D. - If `True` (bool), all levels use 3D depthwise convolutions. - If `False` (bool), all levels use 2D depthwise convolutions (1xKxK kernels for 3D input). - If `List[bool]`, specifies for each level whether to use 3D (True) or 2D (False) kernels. Defaults to True. stem_k_size : int, optional Size of the kernel for the initial stem layer's pooling/convolution. Defaults to 2. contrast : bool, optional Whether to add a contrastive learning projection head to the model. If True, an additional output `embed` will be available in the forward pass. Defaults to `False`. contrast_proj_dim : int, optional Dimension of the projection head for contrastive learning, if `contrast` is True. Defaults to `256`. return_one_tensor : bool, optional If True, concatenates all outputs into a single tensor along the channel dimension in the forward pass. Defaults to `False`. """ super(U_NeXt_V2, self).__init__() if len(output_channels) == 0: raise ValueError("'output_channels' needs to has at least one value") if contrast and len(output_channels) > 2: raise ValueError("If 'contrast' is True, 'output_channels' can only have two values at max: one for the main output and one for the class.") print("Selected output channels:") for i, info in enumerate(output_channel_info): print(f" - {i} channel for {info} output") self.depth = len(feature_maps) - 1 self.ndim = 3 if len(image_shape) == 4 else 2 self.z_down = z_down self.yx_down = yx_down self.output_channels = output_channels self.output_channel_info = output_channel_info self.return_class = True if "class" in output_channel_info else False layer_norm = nn.LayerNorm self.contrast = contrast self.explicit_activations = explicit_activations self.return_one_tensor = return_one_tensor if self.explicit_activations: assert len(head_activations) == sum(output_channels), "If 'explicit_activations' is True, 'head_activations' needs to " "have the same number of values as 'output_channels'" self.head_activations, self.class_head_activations = prepare_activation_layers(head_activations, output_channel_info, output_channels) if self.return_class and self.class_head_activations is None: raise ValueError("If 'return_class' is True, 'head_activations' must be provided.") # convert isotropy to list if it is a single bool if type(isotropy) == bool: isotropy = [isotropy] * len(feature_maps) if self.ndim == 3: conv = nn.Conv3d convtranspose = nn.ConvTranspose3d pre_ln_permutation = Permute([0, 2, 3, 4, 1]) post_ln_permutation = Permute([0, 4, 1, 2, 3]) dropout = nn.Dropout3d else: conv = nn.Conv2d convtranspose = nn.ConvTranspose2d pre_ln_permutation = Permute([0, 2, 3, 1]) post_ln_permutation = Permute([0, 3, 1, 2]) dropout = nn.Dropout2d # Super-resolution self.pre_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "pre": self.pre_upsampling = convtranspose( image_shape[-1], image_shape[-1], kernel_size=upsampling_factor, stride=upsampling_factor, ) self.down_path = nn.ModuleList() self.downsample_layers = nn.ModuleList() in_channels = image_shape[-1] # STEM z_factor = int(max(z_down[0] / stem_k_size, 1)) mpool = (stem_k_size * z_factor, stem_k_size, stem_k_size) if self.ndim == 3 else (stem_k_size, stem_k_size) self.down_path.append( nn.Sequential( conv(in_channels, feature_maps[0], kernel_size=mpool, stride=mpool), pre_ln_permutation, layer_norm(feature_maps[0]), post_ln_permutation, ) ) # depthwise kernel size for ConvNeXt block kernel_size = (7, 7) if self.ndim == 2 else (7, 7, 7) # Encoder stage_block_id = 0 total_stage_blocks = sum(cn_layers) sd_probs = [] for i in range(self.depth): stage = nn.ModuleList() sd_probs_stage = [] # adjust depthwise kernel size if needed if not isotropy[i] and self.ndim == 3: kernel_size = (1, 7, 7) # ConvNeXtBlocks for _ in range(cn_layers[i]): sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append( ConvNeXtBlock_V2(self.ndim, conv, feature_maps[i], sd_prob, layer_norm, k_size=kernel_size) ) stage_block_id += 1 sd_probs_stage.append(sd_prob) self.down_path.append(nn.Sequential(*stage)) sd_probs.append(sd_probs_stage) # Downsampling mpool = (z_down[i], yx_down[i], yx_down[i]) if self.ndim == 3 else (yx_down[i], yx_down[i]) self.downsample_layers.append( nn.Sequential( pre_ln_permutation, layer_norm(feature_maps[i]), post_ln_permutation, conv( feature_maps[i], feature_maps[i + 1], kernel_size=mpool, stride=mpool, ), ) ) # BOTTLENECK stage = nn.ModuleList() if not isotropy[-1] and self.ndim == 3: kernel_size = (1, 7, 7) for _ in range(cn_layers[-1]): sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append(ConvNeXtBlock_V2(self.ndim, conv, feature_maps[-1], sd_prob, layer_norm, k_size=kernel_size)) stage_block_id += 1 self.bottleneck = nn.Sequential(*stage) # DECODER self.num_decoders = 1 if not separated_decoders else len(output_channels) self.up_paths = nn.ModuleList([nn.ModuleList() for _ in range(self.num_decoders)]) for j in range(self.num_decoders): in_channels = feature_maps[-1] for i in range(self.depth - 1, -1, -1): if not isotropy[i] and self.ndim == 3: kernel_size = (1, 7, 7) self.up_paths[j].append( UpConvNeXtBlock_V2( ndim=self.ndim, convtranspose=convtranspose, in_size=in_channels, out_size=feature_maps[i], z_down=z_down[i], yx_down=yx_down[i], up_mode=upsample_layer, conv=conv, attention_gate=False, cn_layers=cn_layers[i], sd_probs=sd_probs[i], layer_norm=layer_norm, k_size=kernel_size, ) # type: ignore ) in_channels = feature_maps[i] # Inverted Stem mpool = (stem_k_size * z_factor, stem_k_size, stem_k_size) if self.ndim == 3 else (stem_k_size, stem_k_size) self.up_paths[j].append( nn.Sequential( convtranspose(feature_maps[0], feature_maps[0], kernel_size=mpool, stride=mpool), pre_ln_permutation, layer_norm(feature_maps[0]), post_ln_permutation, ) # type: ignore ) # Super-resolution self.post_upsampling = None if len(upsampling_factor) > 0 and upsampling_position == "post": self.post_upsampling = convtranspose( feature_maps[0], feature_maps[0], kernel_size=upsampling_factor, stride=upsampling_factor, ) if self.contrast: # extra added layers self.heads = nn.Sequential( conv(feature_maps[0], feature_maps[0], kernel_size=3, stride=1, padding=1), layer_norm(feature_maps[0]), dropout(0.10), conv(feature_maps[0], output_channels[0], kernel_size=1, stride=1, padding=0, bias=False), ) self.proj_head = ProjectionHead(ndim=self.ndim, in_channels=feature_maps[0], proj_dim=contrast_proj_dim) else: self.heads = nn.Sequential() for i, out_ch in enumerate(output_channels): self.heads.append(conv(feature_maps[0], out_ch, kernel_size=1, padding="same")) init_weights(self)
[docs] def forward(self, x) -> Dict | torch.Tensor: """ Forward pass of the model. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, channels, height, width) for 2D or (batch_size, channels, depth, height, width) for 3D. Returns ------- Dict or torch.Tensor Model output. Returns a dictionary if multi-head or contrastive outputs are enabled, otherwise returns the main prediction tensor. """ # Super-resolution if self.pre_upsampling: x = self.pre_upsampling(x) # Encoder blocks = [] x = self.down_path[0](x) # (stem) for i, layers in enumerate(zip(self.down_path[1:], self.downsample_layers)): down, pool = layers x = down(x) blocks.append(x) x = pool(x) x_bot = self.bottleneck(x) # Decoder feats = [] for j in range(self.num_decoders): x = x_bot for i, up in enumerate(self.up_paths[j][:-1]): x = up(x, blocks[-i - 1]) x = self.up_paths[j][-1](x) feats.append(x) # Super-resolution if self.post_upsampling: feats[0] = self.post_upsampling(feats[0]) out_dict = {} # Pass the features through the output heads class_outs, outs = [], [] for i, head in enumerate(self.heads): feat = feats[i] if self.num_decoders > 1 else feats[0] if "class" in self.output_channel_info[i]: class_outs.append(head(feat)) else: outs.append(head(feat)) outs = torch.cat(outs, dim=1) # Apply activations to the output heads if explicit_activations is True if self.explicit_activations: # If there is only one activation, apply it to the whole tensor if len(self.head_activations) == 1: outs = self.head_activations[0](outs) else: for i, act in enumerate(self.head_activations): outs[:, i:i+1] = act(outs[:, i:i+1]) if self.return_class and self.class_head_activations is not None: for i, act in enumerate(self.class_head_activations): class_outs[i] = act(class_outs[i]) out_dict = { "pred": outs, } if self.return_class: out_dict["class"] = torch.cat(class_outs, dim=1) # Contrastive learning head if self.contrast: out_dict["embed"] = self.proj_head(feats[0]) if len(out_dict.keys()) == 1: return out_dict["pred"] else: if self.return_one_tensor: if "class" in out_dict: return torch.cat((out_dict["pred"], torch.argmax(out_dict["class"], dim=1).unsqueeze(1)), dim=1) else: return out_dict["pred"] return out_dict