UNETR

class biapy.models.unetr.UNETR(input_shape, patch_size, embed_dim, depth, num_heads, mlp_ratio=4.0, num_filters=16, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, n_classes=1, decoder_activation='relu', ViT_hidd_mult=3, batch_norm=True, dropout=0.0, k_size=3, output_channels='BC')[source]

Bases: Module

UNETR architecture. It combines a ViT with a U-Net: the convolutional encoder is replaced by the ViT, and each skip-connection signal is adapted to its layer’s spatial dimensionality.

Reference: UNETR: Transformers for 3D Medical Image Segmentation.

Parameters:
  • input_shape (3D/4D tuple) – Dimensions of the input image. E.g. (y, x, channels) or (z, y, x, channels).

  • patch_size (int) – Size of the patches that are extracted from the input image. As an example, to use 16x16 patches, set patch_size = 16.

  • embed_dim (int) – Dimension of the embedding space.

  • depth (int) – Number of transformer encoder layers.

  • num_heads (int) – Number of heads in the multi-head attention layer.

  • mlp_ratio (float, optional) – Ratio by which embed_dim is multiplied to obtain the hidden dimension of the dense (MLP) layers inside each transformer encoder block.

  • num_filters (int, optional) – Number of filters in the first layer of the UNETR decoder. The number of filters is doubled in each subsequent layer.

  • norm_layer (Torch layer, optional) – Normalization layer to use in the ViT backbone.

  • n_classes (int, optional) – Number of classes to predict. It determines the number of channels in the output tensor.

  • decoder_activation (str, optional) – Activation function for the decoder.

  • ViT_hidd_mult (int, optional) – Multiple of the transformer encoder layers from which the skip-connection signals are extracted. E.g. with 12 transformer encoder layers and ViT_hidd_mult = 3, the signals from layers [1*ViT_hidd_mult, 2*ViT_hidd_mult, 3*ViT_hidd_mult] -> [Z3, Z6, Z9] are used (see the sketch after this parameter list).

  • batch_norm (bool, optional) – Whether to use batch normalization or not.

  • dropout (float or list of float, optional) – Dropout rate for the decoder (a list sets a per-layer rate).

  • k_size (int, optional) – Decoder convolutions’ kernel size.

  • output_channels (str, optional) – Output channel configuration to operate with. Possible values: BC, BCD, BP, BCDv2, BDv2, Dv2 and BCM.
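
As a minimal sketch of how the ViT_hidd_mult skip indices above are derived (plain Python, independent of BiaPy):

    depth = 12          # number of transformer encoder layers
    ViT_hidd_mult = 3   # multiple at which skip signals are tapped
    # UNETR takes three skip connections from the encoder:
    skip_layers = [ViT_hidd_mult * i for i in range(1, 4)]
    print(skip_layers)  # [3, 6, 9] -> Z3, Z6, Z9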

Returns:

model – UNETR model.

Return type:

Torch model
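
A minimal usage sketch (the hyperparameter values below are illustrative ViT-Base-like choices, not BiaPy defaults, and the channel-first tensor layout is the usual PyTorch convention, assumed here rather than taken from BiaPy’s docs):

    import torch
    from biapy.models.unetr import UNETR

    # 2D setup: 256x256 single-channel images split into 16x16 ViT patches;
    # with depth=12 and ViT_hidd_mult=3, skips are taken from Z3, Z6 and Z9.
    model = UNETR(
        input_shape=(256, 256, 1),
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_filters=16,
        n_classes=1,
        output_channels="BC",
    )

    x = torch.randn(2, 1, 256, 256)  # (batch, channels, y, x), assumed layout
    out = model(x)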

proj_feat(x)[source]
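
proj_feat reshapes the ViT token sequence back into a spatial feature map that the convolutional decoder can consume. A minimal sketch of this projection for the 2D case, assuming BiaPy follows the reference UNETR behaviour (names and signature here are illustrative, not BiaPy’s exact internals):

    def proj_feat(x, hidden_size, feat_size):
        # x: (batch, n_tokens, hidden_size) output of a ViT encoder block.
        b = x.size(0)
        # Unflatten the token axis into the (y, x) grid of patches.
        x = x.view(b, feat_size[0], feat_size[1], hidden_size)
        # Move channels before the spatial axes: (batch, hidden_size, y, x).
        x = x.permute(0, 3, 1, 2).contiguous()
        return x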
forward(input) → Tensor | List[Tensor][source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
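
For example, with the model instance from the usage sketch above:

    out = model(x)          # correct: runs any registered hooks
    out = model.forward(x)  # works, but silently skips those hooks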