MAE

class biapy.models.mae.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=3, ndim=2, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, norm_pix_loss=False, masking_type='random', mask_ratio=0.5, device=None)[source]

Bases: Module

Masked autoencoder (MAE) with a Vision Transformer (ViT) backbone.

Reference: Masked Autoencoders Are Scalable Vision Learners.

Parameters:
  • img_size (int, optional) – Size of the input image.

  • patch_size (int, optional) – Size of the patches (token size) the input is divided into for the transformer.

  • in_chans (int, optional) – Number of channels.

  • ndim (int, optional) – Number of input dimensions.

  • embed_dim (int, optional) – Size of the transformer embedding.

  • depth (int, optional) – Number of layers of the transformer.

  • num_heads (int, optional) – Number of heads in the multi-head attention layer.

  • mlp_ratio (float, optional) – Ratio of the hidden size of the transformer MLP blocks with respect to embed_dim, i.e. the hidden size is mlp_ratio * embed_dim.

  • decoder_embed_dim (int, optional) – Size of the transformer embedding in the decoder.

  • decoder_depth (int, optional) – Number of layers of the decoder.

  • decoder_num_heads (int, optional) – Number of heads in the multi-head attention layer in the decoder.

  • norm_layer (Torch layer, optional) – Normalization layer to use in the model.

  • norm_pix_loss (bool, optional) – Whether to use (per-patch) normalized pixels as targets when computing the loss.

  • masking_type (str, optional) – Type of masking to apply, either 'random' or 'grid' (see random_masking() and grid_masking()).

  • mask_ratio (float, optional) – Fraction of the input image to mask. Value between 0 and 1.

  • device (Torch device) – Device used.

Returns:

model – MAE model.

Return type:

Torch model
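
A minimal construction sketch, assuming only the constructor signature shown above. The concrete values (image size, channel count, transformer sizes) are illustrative choices, not recommendations:

    import torch
    from biapy.models.mae import MaskedAutoencoderViT

    # Small 2D MAE; every argument below is documented in the signature above.
    model = MaskedAutoencoderViT(
        img_size=128,           # 128 x 128 input images
        patch_size=16,          # 16 x 16 tokens -> (128 // 16) ** 2 = 64 patches
        in_chans=1,             # e.g. single-channel microscopy images
        ndim=2,                 # 2D variant
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        masking_type="random",  # or "grid" (see random_masking() / grid_masking())
        mask_ratio=0.75,        # mask 75% of the patches
    )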

initialize_weights()[source]

Initialize layer weights.

patchify(imgs)[source]

Create patches from input image. Opposite function of unpatchify().

Parameters:

imgs (Tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C are the channels, Z image depth, H image height and W image’s width.

Returns:

x – Patchified images. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L is the number of patches (the product of the patched spatial dimensions, i.e. Z, H and W) and C are the channels.

Return type:

Torch tensor

unpatchify(x)[source]

Create original image shape from input patches. Opposite function of patchify().

Parameters:

x (Tensor) – Patchified images. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L is the number of patches and C are the channels.

Returns:

imgs – Reconstructed images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C are the channels, Z image depth, H image height and W image's width.

Return type:

Torch tensor
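
A short round-trip sketch for patchify() and unpatchify() in the 2D case, reusing the model (and the torch import) from the construction sketch above. Shapes follow the descriptions given for both methods:

    imgs = torch.rand(4, 1, 128, 128)        # (N, C, H, W)

    patches = model.patchify(imgs)           # (N, L, patch_size**2 * C)
    print(patches.shape)                     # torch.Size([4, 64, 256]); L = (128 // 16) ** 2 = 64

    recovered = model.unpatchify(patches)    # back to (N, C, H, W)
    print(torch.allclose(recovered, imgs))   # True: the two operations are inverses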

random_masking(x)[source]

Perform per-sample random masking by per-sample shuffling. The shuffling is done by argsorting random noise.

Parameters:

x (Tensor) – Input token sequence of shape (N, L, D). Where N is the batch size, L is the number of patches and D is embed_dim.
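
The snippet below is a standalone sketch of the mechanism described here (argsorting per-sample random noise). The returned triple of kept tokens, binary mask and restore indices follows the reference MAE implementation and is an assumption about this method's exact output:

    import torch

    def random_masking_sketch(x, mask_ratio):
        # x: (N, L, D) token sequence, as documented above
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)        # one random score per patch
        ids_shuffle = torch.argsort(noise, dim=1)        # patches with low scores are kept
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # indices that undo the shuffle

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Binary mask over all L patches: 0 = keep, 1 = remove
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore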

grid_masking(x)[source]

Perform grid masking for each sample.

Parameters:

x (Tensor) – Input token sequence of shape (N, L, D). Where N is the batch size, L is the number of patches and D is embed_dim.

forward_encoder(x)[source]

Encoder forward pass.

forward_decoder(x, ids_restore)[source]

Decoder forward pass.

forward_loss(imgs, pred, mask)[source]

MAE loss calculation.

Parameters:
  • imgs (Tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C are the channels, Z image depth, H image height and W image’s width.

  • pred (Tensor) – Predictions. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L is the number of patches and C are the channels.

  • mask (2D array) – Binary mask indicating which patches were kept and which were removed. Shape is (N, L), where 0 means keep and 1 means remove.

Returns:

loss – Calculated loss on masked patches only.

Return type:

Tensor
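
A sketch of the computation this method is documented to perform: per-patch mean squared error between the patchified targets and the predictions, averaged over the removed patches only, with the optional per-patch normalization controlled by norm_pix_loss:

    def forward_loss_sketch(model, imgs, pred, mask, norm_pix_loss=False):
        target = model.patchify(imgs)               # (N, L, patch_size**2 * C)
        if norm_pix_loss:
            # Normalize each target patch to zero mean / unit variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                    # per-patch loss, shape (N, L)
        return (loss * mask).sum() / mask.sum()     # mean over removed patches only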

forward(imgs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
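
As the note says, the model should be invoked by calling the instance rather than forward() directly. The unpacking below assumes the forward pass follows the reference MAE and returns the loss, the per-patch predictions and the binary mask; treat that triple as an assumption, not part of the documented interface:

    imgs = torch.rand(4, 1, 128, 128)   # (N, C, H, W), matching the model built above
    loss, pred, mask = model(imgs)      # assumed (loss, pred, mask) triple, as in the MAE paper
    loss.backward()                     # typical self-supervised training step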

save_images(_x, _y, _mask, dtype)[source]

Save images from MAE.

Parameters:
  • _x (Torch tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C are the channels, Z image depth, H image height and W image’s width.

  • _y (Torch tensor) – Model predictions. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L is the number of patches and C are the channels.

  • _mask (2D array) – Binary mask indicating which patches were kept and which were removed. Shape is (N, L), where 0 means keep and 1 means remove.

  • dtype (Numpy dtype) – Dtype to save the images.

Returns:

  • pred (4D/5D Numpy array) – Predicted images converted to Numpy. In 2D: (N, H, W, C), in 3D: (N, Z, H, W, C). Where N is the batch size, C are the channels, Z image depth, H image height and W image’s width.

  • p_mask (4D/5D Numpy array) – Mask of the predicted images. In 2D: (N, H, W, C), in 3D: (N, Z, H, W, C). Where N is the batch size, C are the channels, Z image depth, H image height and W image's width.

  • pred_visi (4D/5D Numpy array) – Predicted image with visible patches. In 2D: (N, H, W, C), in 3D: (N, Z, H, W, C). Where N is the batch size, C are the channels, Z image depth, H image height and W image’s width.

biapy.models.mae.mae_vit_base_patch16_dec512d8b(**kwargs)[source]
biapy.models.mae.mae_vit_large_patch16_dec512d8b(**kwargs)[source]
biapy.models.mae.mae_vit_huge_patch14_dec512d8b(**kwargs)[source]
biapy.models.mae.mae_vit_base_patch16(**kwargs)
biapy.models.mae.mae_vit_large_patch16(**kwargs)
biapy.models.mae.mae_vit_huge_patch14(**kwargs)
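
The factory functions above build presets of MaskedAutoencoderViT (ViT-Base / Large / Huge encoders with a 512-dim, 8-block decoder, as encoded in their names). A hedged usage sketch, assuming that the remaining keyword arguments are forwarded to the constructor; the values are illustrative only:

    from biapy.models.mae import mae_vit_base_patch16

    # Arguments not fixed by the preset (channels, image size, masking) are
    # assumed to pass through to MaskedAutoencoderViT.
    model = mae_vit_base_patch16(img_size=128, in_chans=1, ndim=2, mask_ratio=0.75)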