biapy.models.unetr
This module implements the UNETR (U-Net TRansformers) architecture, a hybrid deep learning model that combines the strengths of Vision Transformers (ViT) with the U-Net’s skip-connection mechanism.
UNETR replaces the traditional convolutional encoder of a U-Net with a ViT, allowing it to capture long-range dependencies effectively. Intermediate ViT representations are reshaped from token sequences back into spatial feature maps and upsampled to the resolution of the corresponding decoder level, where they are merged through skip connections. This design is particularly well-suited for 3D medical image segmentation.
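The following pure-Python sketch makes the resolution bookkeeping concrete. It reproduces the layout described in the original UNETR paper, where blocks 3, 6, 9 and 12 of a 12-block ViT feed the decoder. The helper name `skip_levels` is hypothetical, and so is the assumption that BiaPy's `ViT_hidd_mult` parameter sets the tap spacing. Consult the BiaPy source for the exact wiring.

```python
import math


def skip_levels(input_size, patch_size, depth, hidd_mult):
    """Hypothetical helper: map tapped ViT blocks to decoder resolutions.

    Follows the original UNETR layout (taps every `hidd_mult` blocks plus the
    last block). Assumes `patch_size` is a power of two and that
    `input_size` is divisible by it.
    """
    grid = input_size // patch_size  # tokens along one spatial axis
    n_up = int(math.log2(patch_size))  # 2x upsamplings needed to get back to input_size
    taps = [hidd_mult * k for k in (1, 2, 3)] + [depth]
    levels = []
    for i, block in enumerate(taps):
        ups = len(taps) - 1 - i  # shallower taps are upsampled more
        # (block index, number of 2x upsamplings, resulting side length)
        levels.append((block, ups, grid * 2 ** ups))
    return n_up, levels


n_up, levels = skip_levels(input_size=96, patch_size=16, depth=12, hidd_mult=3)
print(n_up)    # 4
print(levels)  # [(3, 3, 48), (6, 2, 24), (9, 1, 12), (12, 0, 6)]
```

The last tapped block stays at the token-grid resolution (96 / 16 = 6). Earlier blocks are pushed up to progressively finer grids so that each one lines up with a decoder level.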
Classes:
UNETR: The main UNETR model, integrating a ViT encoder with a U-Net-like decoder.
This module leverages components from biapy.models.blocks such as DoubleConvBlock, ConvBlock, ProjectionHead, and normalization helpers (get_norm_2d, get_norm_3d), as well as PatchEmbed from biapy.models.tr_layers.
Reference: UNETR: Transformers for 3D Medical Image Segmentation.
- class biapy.models.unetr.UNETR(input_shape, patch_size, embed_dim, depth, num_heads, mlp_ratio=4.0, num_filters=16, norm_layer=torch.nn.LayerNorm, output_channels=[1], output_channel_info=['F'], explicit_activations: bool = False, head_activations: List[str] = ['ce_sigmoid'], decoder_activation='relu', ViT_hidd_mult=3, normalization='bn', dropout=0.0, k_size=3, contrast: bool = False, contrast_proj_dim: int = 256, return_one_tensor: bool = False)
Bases: torch.nn.Module
UNETR (U-Net TRansformers) architecture.
This model combines a Vision Transformer (ViT) as an encoder with a U-Net-like convolutional decoder. The ViT processes input images as sequences of patches, capturing global context, while the decoder reconstructs the output by upsampling and integrating skip connections from the ViT’s intermediate layers.
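A minimal NumPy sketch can illustrate the patch-sequence view. NumPy stands in for the PyTorch tensors that BiaPy actually uses, where this step is delegated to PatchEmbed. The hypothetical `patchify_2d` helper cuts a channels-first 2D image into non-overlapping patches and flattens each one, producing the token sequence a ViT consumes before its linear embedding.

```python
import numpy as np


def patchify_2d(img, patch_size):
    """Split a (C, H, W) image into a sequence of flattened non-overlapping patches.

    Returns an array of shape (num_patches, C * patch_size * patch_size).
    Assumes H and W are divisible by patch_size.
    """
    c, h, w = img.shape
    gh, gw = h // patch_size, w // patch_size
    # (C, gh, p, gw, p) -> (gh, gw, C, p, p): group pixels by patch, row-major over the grid
    x = img.reshape(c, gh, patch_size, gw, patch_size).transpose(1, 3, 0, 2, 4)
    return x.reshape(gh * gw, c * patch_size * patch_size)


img = np.arange(2 * 4 * 4).reshape(2, 4, 4)  # C=2, H=W=4
seq = patchify_2d(img, patch_size=2)
print(seq.shape)        # (4, 8)
print(seq[0].tolist())  # [0, 1, 4, 5, 16, 17, 20, 21]
```

Applying a shared linear layer to each flattened patch is equivalent to a convolution whose kernel size and stride both equal the patch size. Patch embeddings are commonly implemented that way.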
Reference: UNETR: Transformers for 3D Medical Image Segmentation.
- proj_feat(x)
Reshape and permute the flattened ViT feature tensor back into a spatial feature map format suitable for convolutional layers.
- Parameters:
x (torch.Tensor) – Flattened feature tensor from the ViT encoder, typically (batch_size, num_patches, embed_dim).
- Returns:
Reshaped and permuted feature tensor, e.g., (batch_size, embed_dim, D, H, W) for 3D or (batch_size, embed_dim, H, W) for 2D.
- Return type:
torch.Tensor
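The reshaping that `proj_feat` performs can be sketched in NumPy as follows. This is an illustrative re-implementation, not BiaPy's code. It assumes the tokens were flattened in row-major order over the patch grid, and the `grid` argument (patches per spatial axis) is a hypothetical parameter. In PyTorch the same effect is typically achieved with `Tensor.view` followed by `Tensor.permute`.

```python
import numpy as np


def proj_feat(x, grid):
    """Sketch of proj_feat: (B, N, C) tokens -> (B, C, *grid) feature map.

    `grid` is the number of patches per spatial axis, e.g. (d, h, w) for 3D or
    (h, w) for 2D. Assumes tokens were flattened in row-major grid order.
    """
    b, n, c = x.shape
    if n != int(np.prod(grid)):
        raise ValueError("token count must equal the number of grid cells")
    x = x.reshape(b, *grid, c)  # (B, d, h, w, C) or (B, h, w, C)
    # move the embedding axis to position 1 (channels-first for conv layers)
    return np.moveaxis(x, -1, 1)


tokens = np.zeros((2, 6 * 6 * 6, 768))  # e.g. 96^3 volume, 16^3 patches, embed_dim=768
print(proj_feat(tokens, (6, 6, 6)).shape)  # (2, 768, 6, 6, 6)
```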
- forward(input) → Dict | Tensor
Forward pass of the model.
- Parameters:
input (torch.Tensor) – Input tensor of shape (batch_size, channels, height, width) for 2D or (batch_size, channels, depth, height, width) for 3D.
- Returns:
Model output. Returns a dictionary if multi-head or contrastive outputs are enabled, otherwise returns the main prediction tensor.
- Return type:
Dict or torch.Tensor
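Because `forward` may return either a dictionary or a single tensor, downstream code often normalizes the result. The sketch below is a hypothetical helper. Its default key `"pred"` is an assumption, so check the keys of the dictionary your BiaPy version actually returns.

```python
from collections.abc import Mapping
from typing import Any


def main_prediction(output: Any, key: str = "pred") -> Any:
    """Return the main prediction whether forward() gave a dict or a single tensor.

    `key` is a hypothetical placeholder for the name of the main-prediction entry.
    """
    if isinstance(output, Mapping):
        if key not in output:
            raise KeyError(f"{key!r} not in model output; available keys: {list(output)}")
        return output[key]
    return output  # already the prediction tensor


print(main_prediction({"pred": "mask", "embed": "proj"}))  # mask
```

The constructor's `return_one_tensor` flag may also be relevant when a single tensor is required. Confirm its exact behavior in the source.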