UNETR
- class biapy.models.unetr.UNETR(input_shape, patch_size, embed_dim, depth, num_heads, mlp_ratio=4.0, num_filters=16, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, n_classes=1, decoder_activation='relu', ViT_hidd_mult=3, batch_norm=True, dropout=0.0, k_size=3, output_channels='BC')
Bases: Module
UNETR architecture. It combines a ViT with a U-Net: the convolutional encoder is replaced by the ViT, and each skip connection signal is adapted to the spatial dimensionality of its decoder layer.
Reference: UNETR: Transformers for 3D Medical Image Segmentation.
- Parameters:
input_shape (3D/4D tuple) – Dimensions of the input image, e.g. (y, x, channels) or (z, y, x, channels).
patch_size (int) – Size of the patches extracted from the input image. For example, to use 16x16 patches, set patch_size = 16.
embed_dim (int) – Dimension of the embedding space.
depth (int) – Number of transformer encoder layers.
num_heads (int) – Number of heads in the multi-head attention layer.
mlp_ratio (float, optional) – Ratio by which embed_dim is multiplied to obtain the size of the dense layers of the final classifier.
num_filters (int, optional) – Number of filters in the first layer of the UNETR decoder. Each subsequent layer doubles the number of filters of the previous one.
norm_layer (Torch layer, optional) – Normalization layer to use in the ViT backbone.
n_classes (int, optional) – Number of classes to predict, i.e. the number of channels in the output tensor.
decoder_activation (str, optional) – Activation function for the decoder.
ViT_hidd_mult (int, optional) – Multiple of the transformer encoder layers from which the skip connection signals are extracted. E.g. with 12 transformer encoder layers and ViT_hidd_mult = 3, the signals are taken from layers [1*ViT_hidd_mult, 2*ViT_hidd_mult, 3*ViT_hidd_mult] -> [Z3, Z6, Z9].
batch_norm (bool, optional) – Whether to use batch normalization or not.
dropout (float or list of floats, optional) – Dropout rate for the decoder (can be a list of dropout rates, one per layer).
k_size (int, optional) – Kernel size of the decoder convolutions.
output_channels (str, optional) – Channels to operate with. Possible values: BC, BCD, BP, BCDv2, BDv2, Dv2 and BCM.
- Returns:
model – UNETR model.
- Return type:
Torch model
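The arithmetic implied by patch_size and ViT_hidd_mult can be sketched in plain Python. This is only an illustration, not BiaPy code: the helper names patch_grid and skip_layers are hypothetical, and the sketch assumes a channels-last input_shape, the same patch_size along every spatial axis, spatial dimensions divisible by patch_size, and three skip connections (as in the [Z3, Z6, Z9] example).

```python
from math import prod


def patch_grid(input_shape, patch_size):
    """Return (grid, num_patches) for a channels-last input shape.

    Hypothetical helper. Assumes every spatial dimension is divisible by
    patch_size and that the same patch size is used along every axis.
    """
    *spatial, _channels = input_shape  # (y, x, c) or (z, y, x, c)
    if any(dim % patch_size for dim in spatial):
        raise ValueError("every spatial dimension must be divisible by patch_size")
    grid = tuple(dim // patch_size for dim in spatial)
    return grid, prod(grid)


def skip_layers(depth, vit_hidd_mult, num_skips=3):
    """Return the encoder layers tapped for skip connections.

    Hypothetical helper. num_skips=3 is an assumption taken from the
    [1*ViT_hidd_mult, 2*ViT_hidd_mult, 3*ViT_hidd_mult] example.
    """
    layers = [i * vit_hidd_mult for i in range(1, num_skips + 1)]
    if layers[-1] > depth:
        raise ValueError("ViT_hidd_mult is too large for the given depth")
    return layers


print(patch_grid((256, 256, 1), 16))      # ((16, 16), 256)
print(patch_grid((64, 128, 128, 1), 16))  # ((4, 8, 8), 256)
print(skip_layers(depth=12, vit_hidd_mult=3))  # [3, 6, 9]
```

Under these assumptions, a depth of 12 with ViT_hidd_mult = 5 would request layer 15, which does not exist, so the sketch raises an error. How the actual model handles such a combination is not stated here.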
- forward(input) → Tensor | List[Tensor]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
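The difference between calling the instance and calling forward directly can be illustrated with a toy stand-in. The TinyModule class below is hypothetical and heavily simplified: it is not PyTorch's implementation, it only mimics the idea that the call operator wraps forward and then runs any registered forward hooks.

```python
class TinyModule:
    """Toy stand-in for a module with forward hooks (hypothetical, simplified)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # A hook receives (module, input, output) and may return a new output.
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The "recipe" of the forward pass lives here.
        return x * 2

    def __call__(self, x):
        # Calling the instance runs forward and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:
                out = result
        return out


seen = []
module = TinyModule()
module.register_forward_hook(lambda mod, inp, out: seen.append(out))

module(3)          # hook runs, so seen becomes [6]
module.forward(3)  # hook is skipped, so seen is unchanged
print(seen)        # [6]
```

With a real model the same rule applies: use model(x) rather than model.forward(x) so that any registered hooks take effect.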