Attention U-Net
- class biapy.models.attention_unet.Attention_U_Net(image_shape=(256, 256, 1), activation='ELU', feature_maps=[32, 64, 128, 256], drop_values=[0.1, 0.1, 0.1, 0.1], batch_norm=False, k_size=3, upsample_layer='convtranspose', z_down=[2, 2, 2, 2], n_classes=1, output_channels='BC', upsampling_factor=(), upsampling_position='pre')[source]
Bases:
Module
Create 2D/3D U-Net with Attention blocks.
Reference: Attention U-Net: Learning Where to Look for the Pancreas.
- Parameters:
image_shape (3D/4D tuple) – Dimensions of the input image. E.g. (y, x, channels) or (z, y, x, channels).
activation (str, optional) – Activation layer.
feature_maps (array of ints, optional) – Feature maps to use on each level.
drop_values (array of floats, optional) – Dropout value to be fixed at each level.
batch_norm (bool, optional) – Whether to use batch normalization.
k_size (int, optional) – Kernel size.
upsample_layer (str, optional) – Type of layer to use to make upsampling. Two options: “convtranspose” or “upsampling”.
z_down (list of ints, optional) – Downsampling used in the z dimension. Set it to 1 if the dataset is not isotropic.
n_classes (int, optional) – Number of classes.
output_channels (str, optional) – Channels to operate with. Possible values: BC, BCD, BP, BCDv2, BDv2, Dv2 and BCM.
upsampling_factor (tuple of ints, optional) – Factor of upsampling for the super-resolution workflow, for each dimension.
upsampling_position (str, optional) – Whether the upsampling is applied before the model (pre option) or after the model (post option).
- Returns:
model – Attention U-Net model.
- Return type:
Torch model
Calling this function with its default parameters returns the following network:
Image created with PlotNeuralNet.
That network incorporates Attention Gates (AG) in its skip connections, which can be seen as follows:
Image extracted from Attention U-Net: Learning Where to Look for the Pancreas.
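The AG combines each skip-connection feature map with a gating signal from the coarser level to suppress irrelevant regions before concatenation. A rough NumPy sketch of this additive attention gate follows the paper's formulation; the actual BiaPy layer is implemented with PyTorch convolutions, and the weight names (W_x, W_g, psi) here are illustrative, not BiaPy API:

```python
import numpy as np

rng = np.random.default_rng(0)

def attention_gate(x, g, W_x, W_g, psi):
    """Additive attention gate, as in Attention U-Net (illustrative sketch).

    x   : skip-connection features, shape (H, W, C_x)
    g   : gating features from the coarser level, already resized
          to (H, W, C_g)
    W_x : (C_x, C_int) weights of a 1x1 convolution on x
    W_g : (C_g, C_int) weights of a 1x1 convolution on g
    psi : (C_int, 1) weights collapsing the intermediate features to
          one attention coefficient per pixel
    """
    # A 1x1 convolution is a per-pixel linear map over channels,
    # so it reduces to a matrix product on the channel axis.
    q = np.maximum(x @ W_x + g @ W_g, 0.0)      # ReLU(W_x x + W_g g)
    alpha = 1.0 / (1.0 + np.exp(-(q @ psi)))    # sigmoid -> (H, W, 1) in (0, 1)
    return x * alpha                            # gate the skip features

# Toy example: 8x8 feature maps, 4 skip channels, 4 gating channels.
H, W, C_x, C_g, C_int = 8, 8, 4, 4, 2
x = rng.standard_normal((H, W, C_x))
g = rng.standard_normal((H, W, C_g))
out = attention_gate(x, g,
                     rng.standard_normal((C_x, C_int)),
                     rng.standard_normal((C_g, C_int)),
                     rng.standard_normal((C_int, 1)))
print(out.shape)  # (8, 8, 4): same shape as x, attenuated per pixel
```

Because the attention coefficients lie in (0, 1), the gate only attenuates the skip features; the output keeps the shape of x and is passed on to the decoder in place of the raw skip connection.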
- forward(x) → Tensor | List[Tensor] [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.