biapy.models.mae
This file implements the Masked Autoencoder (MAE) model with a Vision Transformer (ViT) backbone, as described in the paper "Masked Autoencoders Are Scalable Vision Learners" (https://openaccess.thecvf.com/content/CVPR2022/html/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper).
The MAE model is designed for self-supervised pre-training of Vision Transformers by reconstructing masked-out patches of an image. It consists of an encoder that processes visible patches and a lightweight decoder that reconstructs the original image from the encoder's latent representation and mask tokens.
Key components and functionalities include:
Classes:
MaskedAutoencoderViT: The main MAE model, encompassing the encoder and decoder.
Functions:
mae_vit_base_patch16_dec512d8b: Factory function for a base MAE-ViT model.
mae_vit_large_patch16_dec512d8b: Factory function for a large MAE-ViT model.
mae_vit_huge_patch14_dec512d8b: Factory function for a huge MAE-ViT model.
The implementation supports both 2D and 3D image inputs, different masking strategies (random and grid), and provides methods for patching/unpatching images, forward passes through encoder/decoder, and loss calculation.
References:
Masked Autoencoders Are Scalable Vision Learners: https://openaccess.thecvf.com/content/CVPR2022/html/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
- class biapy.models.mae.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=3, ndim=2, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, norm_pix_loss=False, masking_type='random', mask_ratio=0.5, return_just_preds=False, device='cpu')
Bases: Module
Masked Autoencoder (MAE) with Vision Transformer (ViT) backbone.
This model implements the architecture proposed in "Masked Autoencoders Are Scalable Vision Learners" for self-supervised pre-training by reconstructing masked image patches. It comprises an encoder to process unmasked patches and a decoder to reconstruct the full image, including the masked regions.
Reference: Masked Autoencoders Are Scalable Vision Learners.
- Parameters:
img_size (int, optional) – Size of the input image (height and width for 2D, or depth, height, and width for 3D, assuming square/cubic dimensions). Defaults to 224.
patch_size (int, optional) – Size of the square/cubic patch (token) that the image is divided into. Defaults to 16.
in_chans (int, optional) – Number of input image channels (e.g., 3 for RGB, 1 for grayscale). Defaults to 3.
ndim (int, optional) – Number of input dimensions, 2 for 2D images (H, W) or 3 for 3D images (D, H, W). Defaults to 2.
embed_dim (int, optional) – Dimensionality of the embedding space for the Vision Transformer encoder. Defaults to 1024.
depth (int, optional) – Number of transformer encoder blocks (layers). Defaults to 24.
num_heads (int, optional) – Number of attention heads in the multi-head attention layer of the encoder. Defaults to 16.
mlp_ratio (float, optional) – Ratio of the hidden dimension of the MLP block to embed_dim. Defaults to 4.0.
decoder_embed_dim (int, optional) – Dimensionality of the embedding space for the MAE decoder. Defaults to 512.
decoder_depth (int, optional) – Number of transformer decoder blocks (layers). Defaults to 8.
decoder_num_heads (int, optional) – Number of attention heads in the multi-head attention layer of the decoder. Defaults to 16.
norm_layer (Torch layer, optional) – Normalization layer to use throughout the model (e.g., nn.LayerNorm). Defaults to nn.LayerNorm.
norm_pix_loss (bool, optional) – If True, normalize pixel values (mean 0, variance 1) per patch before computing the reconstruction loss. This helps stabilize training. Defaults to False.
masking_type (str, optional) – Type of masking strategy to apply. Can be 'random' for random patch masking or 'grid' for structured grid masking. Defaults to 'random'.
mask_ratio (float, optional) – Fraction of the input image patches to mask out, given as a value between 0 and 1. Only applicable when masking_type is 'random'. Defaults to 0.5.
device (str, optional) – The device (e.g., 'cuda', 'cpu') where the model parameters and input tensors will be stored. Defaults to 'cpu'.
- Returns:
model – The MAE model.
- Return type:
nn.Module
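To illustrate how img_size, patch_size, in_chans, and ndim interact, the following minimal sketch computes the number of patches L and the flattened patch length using the formulas given for patchify on this page. The helper name patch_geometry is hypothetical and not part of biapy, and the divisibility check is an assumption rather than documented behavior.

```python
def patch_geometry(img_size, patch_size, in_chans, ndim):
    """Return (num_patches, patch_dim) for a square (2D) or cubic (3D) input.

    Hypothetical helper for illustration only; not part of biapy.
    """
    # Assumption: the image size must be a multiple of the patch size.
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_axis = img_size // patch_size          # patches along each axis
    num_patches = per_axis ** ndim             # L in the shapes on this page
    patch_dim = patch_size ** ndim * in_chans  # length of one flattened patch
    return num_patches, patch_dim


geom_2d = patch_geometry(img_size=224, patch_size=16, in_chans=3, ndim=2)
geom_3d = patch_geometry(img_size=64, patch_size=16, in_chans=1, ndim=3)
```

With the default 2D settings this gives 196 patches of length 768. The 3D call uses an illustrative 64-voxel cube, not a biapy default.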
- initialize_weights()
Initialize the weights of the model's layers.
This method applies specific initialization strategies to different types of layers within the MAE model:
Truncated normal initialization for positional embeddings (pos_embed, decoder_pos_embed).
Xavier uniform initialization for the patch embedding projection (patch_embed.proj.weight).
Normal initialization for the class token (cls_token) and mask token (mask_token).
_init_weights is called to initialize nn.Linear and nn.LayerNorm layers.
- patchify(imgs)
Convert an input image into a sequence of non-overlapping patches.
This function is the inverse of unpatchify. It rearranges the pixel data from a standard image tensor format into a sequence of flattened patch vectors.
- Parameters:
imgs (Tensor) – Input images.
For 2D: (N, C, H, W), where N is batch size, C are channels, H is height, and W is width.
For 3D: (N, C, Z, H, W), where N is batch size, C are channels, Z is depth, H is height, and W is width.
- Returns:
x – Flattened image patches.
For 2D: (N, L, patch_size**2 * C), where L is the total number of patches ((H*W)/(p*p)).
For 3D: (N, L, patch_size**3 * C), where L is the total number of patches ((Z*H*W)/(p*p*p)).
- Return type:
Torch tensor
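The 2D shape transformation described above can be sketched with plain NumPy reshapes. This is not the biapy implementation (which operates on Torch tensors). The function name patchify_2d is hypothetical, and the channel-last ordering of values inside each flattened patch is an assumption.

```python
import numpy as np


def patchify_2d(imgs, p):
    """Rearrange (N, C, H, W) into (N, L, p*p*C), with L = (H*W)/(p*p).

    Hypothetical NumPy sketch; the within-patch ordering is an assumption.
    """
    n, c, h, w = imgs.shape
    assert h % p == 0 and w % p == 0, "H and W must be multiples of p"
    gh, gw = h // p, w // p                  # patch grid size
    x = imgs.reshape(n, c, gh, p, gw, p)     # split H and W into (grid, patch)
    x = x.transpose(0, 2, 4, 3, 5, 1)        # (N, gh, gw, p, p, C)
    return x.reshape(n, gh * gw, p * p * c)  # flatten grid and patch contents


imgs = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
patches = patchify_2d(imgs, p=4)  # 4 patches per image, each of length 48
```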
- unpatchify(x)
Reconstruct an image from a sequence of flattened patches.
This function is the inverse of patchify. It takes a batch of flattened patches and reshapes them back into standard image tensor format.
- Parameters:
x (Tensor) – Input patches.
For 2D: (N, L, patch_size**2 * C), where N is batch size, L is the number of patches, and C are channels.
For 3D: (N, L, patch_size**3 * C), where N is batch size, L is the number of patches, and C are channels.
- Returns:
imgs – Reconstructed images.
For 2D: (N, C, H, W).
For 3D: (N, C, Z, H, W).
- Return type:
Torch tensor
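As a rough check of the inverse relationship, the NumPy sketch below rebuilds 2D images from flattened patches and compares the result with the original. The name unpatchify_2d is hypothetical. It assumes a square patch grid and a channel-last ordering inside each patch, neither of which is confirmed by this page.

```python
import numpy as np


def unpatchify_2d(x, p, c):
    """Rearrange (N, L, p*p*C) back into (N, C, H, W).

    Hypothetical NumPy sketch assuming a square grid of patches.
    """
    n, l, d = x.shape
    g = int(round(l ** 0.5))                # patches per side
    assert g * g == l and d == p * p * c
    x = x.reshape(n, g, g, p, p, c)         # (N, grid_h, grid_w, p, p, C)
    x = x.transpose(0, 5, 1, 3, 2, 4)       # (N, C, grid_h, p, grid_w, p)
    return x.reshape(n, c, g * p, g * p)


# Build patches by explicit slicing so the round trip is checked independently.
rng = np.random.default_rng(0)
imgs = rng.random((2, 3, 8, 8))
p = 4
patches = np.stack(
    [
        imgs[:, :, i:i + p, j:j + p].transpose(0, 2, 3, 1).reshape(2, -1)
        for i in range(0, 8, p)
        for j in range(0, 8, p)
    ],
    axis=1,
)  # shape (2, 4, 48)
restored = unpatchify_2d(patches, p=p, c=3)
```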
- random_masking(x)
Perform per-sample random masking of input patches.
This method randomly selects a subset of patches to keep (visible) and masks out the rest. The selection is done by shuffling patch indices based on random noise.
- Parameters:
x (Tensor) – Input patches with shape (N, L, D), where N is the batch size, L is the number of patches, and D is the embedding dimension.
- Returns:
x_masked (Tensor) – The input patches with masked patches removed, shape (N, L_keep, D).
mask (Tensor) – A binary mask tensor of shape (N, L), where 0 indicates a kept (visible) patch and 1 indicates a removed (masked) patch.
ids_restore (Tensor) – Indices to restore the original order of patches, shape (N, L).
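The noise-based shuffling described above can be sketched in NumPy as follows. This is an illustrative stand-in for the Torch method, not the biapy code. The explicit mask_ratio and rng arguments, and the rounding rule int(L * (1 - mask_ratio)) for the number of kept patches, are assumptions.

```python
import numpy as np


def random_masking(x, mask_ratio, rng):
    """Per-sample random masking of (N, L, D) patches (NumPy sketch)."""
    n, l, d = x.shape
    len_keep = int(l * (1 - mask_ratio))           # assumed rounding rule
    noise = rng.random((n, l))                     # one random score per patch
    ids_shuffle = np.argsort(noise, axis=1)        # lowest scores are kept
    ids_restore = np.argsort(ids_shuffle, axis=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    # Gather the visible patches for every sample.
    x_masked = np.take_along_axis(x, ids_keep[:, :, None], axis=1)
    # Build the mask in shuffled order, then put it back in original order.
    mask = np.ones((n, l))
    mask[:, :len_keep] = 0
    mask = np.take_along_axis(mask, ids_restore, axis=1)
    return x_masked, mask, ids_restore


rng = np.random.default_rng(0)
x = rng.random((2, 16, 8))
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.5, rng=rng)
```

Each sample gets its own permutation, so the masked positions differ across the batch while the number of kept patches stays the same.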
- grid_masking(x)
Perform grid-based masking for input patches.
This method applies a pre-defined checkerboard-like grid mask to the input patches, ensuring a structured masking pattern.
- Parameters:
x (Tensor) – Input patches with shape (N, L, D), where N is the batch size, L is the number of patches, and D is the embedding dimension.
- Returns:
x_masked (Tensor) – The input patches with masked patches removed based on the grid pattern, shape (N, L_keep, D).
mask (Tensor) – A binary mask tensor of shape (N, L), where 0 indicates a kept (visible) patch and 1 indicates a removed (masked) patch.
ids_restore (Tensor) – Indices to restore the original order of patches, shape (N, L).
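The exact grid pattern used by biapy is not specified on this page. Purely as an illustration of structured masking, the NumPy sketch below applies a strict checkerboard to a 2D patch grid and returns the same three outputs. The function name checkerboard_masking, the grid_h and grid_w arguments, and the choice of which cells are masked are all assumptions.

```python
import numpy as np


def checkerboard_masking(x, grid_h, grid_w):
    """Mask every other patch of a 2D patch grid (hypothetical sketch)."""
    n, l, d = x.shape
    assert l == grid_h * grid_w
    rows, cols = np.indices((grid_h, grid_w))
    # Assumed pattern: cells with odd (row + col) are masked (1), others visible (0).
    mask_1d = ((rows + cols) % 2).reshape(-1)
    ids_shuffle = np.argsort(mask_1d, kind="stable")  # visible patches first
    ids_restore = np.argsort(ids_shuffle, kind="stable")
    len_keep = int((mask_1d == 0).sum())
    x_masked = x[:, ids_shuffle[:len_keep], :]
    # The same fixed pattern is shared by every sample in the batch.
    mask = np.tile(mask_1d, (n, 1)).astype(x.dtype)
    ids_restore = np.tile(ids_restore, (n, 1))
    return x_masked, mask, ids_restore


x = np.arange(2 * 16 * 4, dtype=np.float64).reshape(2, 16, 4)
x_masked, mask, ids_restore = checkerboard_masking(x, grid_h=4, grid_w=4)
```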
- forward_encoder(x)
Perform the forward pass through the MAE encoder.
This method first embeds the input image into patches, adds positional embeddings, applies masking, appends the class token, and then processes the resulting sequence through a series of Transformer encoder blocks.
- Parameters:
x (Tensor) – Input image tensor. Its shape depends on ndim:
For 2D: (N, C, H, W)
For 3D: (N, C, Z, H, W)
- Returns:
latent (Tensor) – The latent representation produced by the encoder, typically of shape (N, L_keep + 1, embed_dim), where L_keep is the number of visible patches.
mask (Tensor) – A binary mask indicating which patches were kept (0) or removed (1), shape (N, L).
ids_restore (Tensor) – Indices to restore the original patch order, shape (N, L).
- forward_decoder(x, ids_restore)
Perform the forward pass through the MAE decoder.
The decoder takes the encoder's latent representation, appends mask tokens, restores the original patch order, adds decoder positional embeddings, and then processes the sequence through a series of Transformer decoder blocks to predict the original pixel values of all patches.
- Parameters:
x (Tensor) – Latent representation from the encoder, shape (N, L_keep + 1, embed_dim).
ids_restore (Tensor) – Indices to restore the original patch order, shape (N, L).
- Returns:
x – The reconstructed patches, shape (N, L, patch_size**ndim * in_chans).
- Return type:
Tensor
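The step of appending mask tokens and restoring the original patch order can be sketched in NumPy as below. It leaves out the projection to decoder_embed_dim, the positional embeddings, and the Transformer blocks. The function name restore_sequence is hypothetical, and placing the class token at position 0 of the latent sequence is an assumption.

```python
import numpy as np


def restore_sequence(latent, ids_restore, mask_token):
    """Fill masked positions with mask tokens and undo the shuffle.

    latent: (N, L_keep + 1, D), class token assumed to be first.
    Returns an array of shape (N, L + 1, D). Hypothetical sketch.
    """
    n, l = ids_restore.shape
    cls_tok, visible = latent[:, :1, :], latent[:, 1:, :]
    n_masked = l - visible.shape[1]
    mask_tokens = np.broadcast_to(mask_token, (n, n_masked, latent.shape[2]))
    x = np.concatenate([visible, mask_tokens], axis=1)          # shuffled order
    x = np.take_along_axis(x, ids_restore[:, :, None], axis=1)  # original order
    return np.concatenate([cls_tok, x], axis=1)


# One sample, four patches; patches 2 and 0 were kept, in that shuffled order.
latent = np.array([[[9.0, 9.0], [2.0, 2.0], [0.0, 0.0]]])
ids_restore = np.array([[1, 3, 0, 2]])
mask_token = np.array([-1.0, -1.0])
out = restore_sequence(latent, ids_restore, mask_token)
```

In this toy input, patches 0 and 2 get their encoded values back at their original positions, and patches 1 and 3 are filled with the mask token.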
- forward_loss(imgs, pred, mask)
Calculate the MAE reconstruction loss.
The loss is computed only on the masked patches. Optionally, pixel values can be normalized per patch before loss calculation.
- Parameters:
imgs (Tensor) – Original input images.
For 2D: (N, C, H, W).
For 3D: (N, C, Z, H, W).
pred (Tensor) – Predicted patches from the decoder, shape (N, L, patch_size**ndim * C).
mask (Tensor) – A binary mask indicating which patches were masked (1) or visible (0), shape (N, L).
- Returns:
loss – The calculated mean squared error (MSE) loss, averaged only over the masked patches.
- Return type:
Tensor
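A minimal NumPy sketch of a masked MSE of this kind follows. It takes the target already in patch form, shape (N, L, D), rather than as images. The function name masked_mse, the eps value, and the use of the unbiased variance (ddof=1) for per-patch normalization are assumptions, not details confirmed by this page.

```python
import numpy as np


def masked_mse(target, pred, mask, norm_pix_loss=False, eps=1e-6):
    """Mean squared error averaged over masked patches only (sketch).

    target, pred: (N, L, D) patch arrays; mask: (N, L) with 1 = masked.
    """
    if norm_pix_loss:
        # Normalize each target patch to zero mean and unit variance.
        mean = target.mean(axis=-1, keepdims=True)
        var = target.var(axis=-1, keepdims=True, ddof=1)  # assumed unbiased
        target = (target - mean) / np.sqrt(var + eps)
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # (N, L) loss per patch
    return (per_patch * mask).sum() / mask.sum()      # visible patches ignored


target = np.zeros((1, 4, 2))
pred = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]])
mask = np.array([[0.0, 1.0, 0.0, 1.0]])
loss = masked_mse(target, pred, mask)  # (4 + 16) / 2 = 10.0
```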
- forward(imgs, return_just_preds=False) -> dict | Tensor
Perform the complete forward pass of the Masked Autoencoder.
This method orchestrates the full MAE process: encoding visible patches, decoding the full image, and calculating the reconstruction loss.
- Parameters:
imgs (Tensor) – Input images.
For 2D: (N, C, H, W).
For 3D: (N, C, Z, H, W).
- Returns:
A dictionary containing:
'loss': The calculated reconstruction loss (Tensor).
'pred': The predicted full patch sequence from the decoder (Tensor), shape (N, L, patch_size**ndim * C).
'mask': The binary mask used during masking (Tensor), shape (N, L).
- Return type:
dict
- save_images(_x, _y, _mask, dtype)
Generate and prepare images for visualization/saving from MAE outputs.
This method reconstructs the predicted image, creates a masked version of the original input, and generates an image where visible patches from the original are combined with reconstructed masked patches.
- Parameters:
_x (Torch tensor) – Original input images.
For 2D: (N, C, H, W).
For 3D: (N, C, Z, H, W).
_y (Torch tensor) – MAE model's predicted patches, shape (N, L, patch_size**ndim * C).
_mask (Torch tensor) – Binary mask indicating masked (1) and visible (0) patches, shape (N, L).
dtype (Numpy dtype) – The desired NumPy data type for the output images.
- Returns:
pred (4D/5D Numpy array) – The fully reconstructed images (from decoder predictions), converted to NumPy.
For 2D: (N, H, W, C).
For 3D: (N, Z, H, W, C).
p_mask (4D/5D Numpy array) – The original input images with only the visible (unmasked) patches remaining, converted to NumPy.
For 2D: (N, H, W, C).
For 3D: (N, Z, H, W, C).
pred_visi (4D/5D Numpy array) – The image where the visible patches are from the original input, and the masked regions are filled with the decoder's predictions, converted to NumPy.
For 2D: (N, H, W, C).
For 3D: (N, Z, H, W, C).
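The blending behind p_mask and pred_visi can be sketched at patch level with NumPy. The example stops before the unpatchify step and the conversion to channel-last layout. The function name combine_visible_and_predicted is hypothetical, and zero-filling the masked patches in the masked input is an assumption.

```python
import numpy as np


def combine_visible_and_predicted(orig_patches, pred_patches, mask):
    """Blend original and predicted patches using a (N, L) mask (sketch).

    mask uses 1 for masked patches and 0 for visible ones.
    """
    m = mask[:, :, None]                  # broadcast over patch contents
    masked_input = orig_patches * (1 - m)  # masked patches zeroed (assumption)
    combined = orig_patches * (1 - m) + pred_patches * m
    return masked_input, combined


orig_patches = np.ones((1, 4, 2))
pred_patches = np.full((1, 4, 2), 5.0)
mask = np.array([[0.0, 1.0, 0.0, 1.0]])
masked_input, combined = combine_visible_and_predicted(orig_patches, pred_patches, mask)
```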
- biapy.models.mae.mae_vit_base_patch16_dec512d8b(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Base-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function serves as a convenient constructor for a specific MAE-ViT configuration, often used as a standard baseline.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a base variant.
- Return type:
- biapy.models.mae.mae_vit_large_patch16_dec512d8b(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Large-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function provides a constructor for a larger MAE-ViT configuration, suitable for tasks requiring more capacity.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a large variant.
- Return type:
- biapy.models.mae.mae_vit_huge_patch14_dec512d8b(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Huge-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function provides a constructor for the largest MAE-ViT configuration, designed for tasks demanding maximum model capacity.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a huge variant.
- Return type:
- biapy.models.mae.mae_vit_base_patch16(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Base-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function serves as a convenient constructor for a specific MAE-ViT configuration, often used as a standard baseline.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a base variant.
- Return type:
- biapy.models.mae.mae_vit_large_patch16(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Large-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function provides a constructor for a larger MAE-ViT configuration, suitable for tasks requiring more capacity.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a large variant.
- Return type:
- biapy.models.mae.mae_vit_huge_patch14(**kwargs)
Create a Masked Autoencoder ViT (MAE-ViT) model with a Huge-sized encoder and a decoder with 512 embedding dimensions and 8 blocks.
This function provides a constructor for the largest MAE-ViT configuration, designed for tasks demanding maximum model capacity.
- Parameters:
**kwargs – Arbitrary keyword arguments to be passed to the MaskedAutoencoderViT constructor. This allows overriding default parameters like img_size, in_chans, norm_pix_loss, masking_type, mask_ratio, or device.
- Returns:
model – An initialized MAE-ViT model configured as a huge variant.
- Return type: