MAE
- class biapy.models.mae.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=3, ndim=2, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, norm_pix_loss=False, masking_type='random', mask_ratio=0.5, device=None)[source]
Bases: Module
Masked autoencoder (MAE) with Vision Transformer (ViT) backbone.
Reference: Masked Autoencoders Are Scalable Vision Learners.
- Parameters:
img_size (int, optional) – Size of the input image.
patch_size (int, optional) – Size of the patches (token size) used by the transformer.
in_chans (int, optional) – Number of channels.
ndim (int, optional) – Number of input dimensions.
embed_dim (int, optional) – Size of the transformer embedding.
depth (int, optional) – Number of layers of the transformer.
num_heads (int, optional) – Number of heads in the multi-head attention layer.
mlp_ratio (float, optional) – Size of the dense layers of the final classifier. This value will multiply embed_dim.
decoder_embed_dim (int, optional) – Size of the transformer embedding in the decoder.
decoder_depth (int, optional) – Number of layers of the decoder.
decoder_num_heads (int, optional) – Number of heads in the multi-head attention layer in the decoder.
norm_layer (Torch layer, optional) – Normalization layer to use in the model.
norm_pix_loss (bool, optional) – Use (per-patch) normalized pixels as targets for computing the loss.
masking_type (str, optional) – Type of masking to apply: "random" or "grid".
mask_ratio (float, optional) – Percentage of the input image to mask. Value between 0 and 1.
device (Torch device) – Device used.
- Returns:
model – MAE model.
- Return type:
Torch model
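A minimal instantiation sketch using a deliberately small configuration (the parameter names follow the signature above; the sizes are illustrative only, since the defaults listed above correspond to a ViT-Large encoder):

    import torch
    from biapy.models.mae import MaskedAutoencoderViT

    # Toy-sized 2D MAE; all arguments are documented in the parameter list above.
    model = MaskedAutoencoderViT(
        img_size=224,
        patch_size=16,
        in_chans=3,
        ndim=2,
        embed_dim=128,
        depth=2,
        num_heads=4,
        decoder_embed_dim=64,
        decoder_depth=2,
        decoder_num_heads=4,
        masking_type="random",
        mask_ratio=0.5,
    )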
- patchify(imgs)[source]
Create patches from the input images. Opposite function of unpatchify().
- Parameters:
imgs (Tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C the number of channels, Z the image depth, H the image height and W the image width.
- Returns:
x – Patchified images. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L the number of patches (the product of the patch counts along Z, H and W) and C the number of channels.
- Return type:
Torch tensor
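Continuing the instantiation sketch above, the documented shapes work out as follows for a 2D input (the patch-count arithmetic is the only addition here):

    imgs = torch.rand(2, 3, 224, 224)   # (N, C, H, W)
    x = model.patchify(imgs)            # (N, L, patch_size**2 * C)
    # With img_size=224 and patch_size=16: L = (224 // 16) ** 2 = 196 and the
    # last dimension is 16 ** 2 * 3 = 768, so x.shape should be (2, 196, 768).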
- unpatchify(x)[source]
Reconstruct the original image shape from input patches. Opposite function of patchify().
- Parameters:
x (Tensor) – Input patches. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L the number of patches (the product of the patch counts along Z, H and W) and C the number of channels.
- Returns:
imgs – Reconstructed images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C the number of channels, Z the image depth, H the image height and W the image width.
- Return type:
Torch tensor
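Since unpatchify() is the inverse of patchify(), a round trip should return the original tensor (sketch, reusing the model and imgs from the examples above):

    x = model.patchify(imgs)            # (N, L, patch_size**2 * C)
    recovered = model.unpatchify(x)     # back to (N, C, H, W)
    assert recovered.shape == imgs.shape
    # Both functions are pure rearrangements, so the values should match as well.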
- random_masking(x)[source]
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise.
- Parameters:
x (Tensor) – Input of shape (N, L, D). Where N is the batch size, L the number of patches and D is embed_dim.
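The argsort-of-random-noise shuffling mentioned above can be illustrated with a standalone sketch (this mirrors the described technique, not necessarily the exact BiaPy code or its return signature):

    import torch

    def random_masking_sketch(x, mask_ratio=0.5):
        # x: (N, L, D). Keep a random subset of len_keep tokens per sample.
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L)                   # one noise value per token
        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: lowest noise is kept
        ids_keep = ids_shuffle[:, :len_keep]

        # Gather the kept tokens for each sample.
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # Binary mask over all tokens: 0 = keep, 1 = remove.
        mask = torch.ones(N, L)
        mask.scatter_(1, ids_keep, 0)
        return x_masked, mask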
- grid_masking(x)[source]
Perform grid masking for each sample.
- Parameters:
x (Tensor) – Input of shape (N, L, D). Where N is the batch size, L the number of patches and D is embed_dim.
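As a rough illustration only (the exact grid pattern BiaPy uses is not spelled out here and may differ), grid masking is often realized as a fixed checkerboard over the patch grid:

    import torch

    def grid_masking_sketch(x, grid_size):
        # x: (N, L, D) with L = grid_size ** 2 patches laid out row by row.
        N, L, D = x.shape
        rows = torch.arange(L) // grid_size
        cols = torch.arange(L) % grid_size
        mask = ((rows + cols) % 2).float()       # checkerboard: 1 = remove
        ids_keep = torch.nonzero(mask == 0).squeeze(1)
        x_masked = x[:, ids_keep, :]
        mask = mask.unsqueeze(0).expand(N, -1)   # (N, L), 0 = keep, 1 = remove
        return x_masked, mask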
- forward_loss(imgs, pred, mask)[source]
MAE loss calculation.
- Parameters:
imgs (Tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C the number of channels, Z the image depth, H the image height and W the image width.
pred (Tensor) – Predictions. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L the number of patches and C the number of channels.
mask (2D array) – Indicates which patches are kept and which are masked. Shape is (N, L), where 0 means keep and 1 means remove.
- Returns:
loss – Calculated loss on masked patches only.
- Return type:
Tensor
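The computation amounts to a per-patch mean squared error averaged over the removed patches only; a minimal sketch (assuming the target has already been patchified, and showing the norm_pix_loss branch described above):

    def mae_loss_sketch(target, pred, mask, norm_pix_loss=False):
        # target, pred: (N, L, patch_size**ndim * C); mask: (N, L), 0 = keep, 1 = remove.
        if norm_pix_loss:
            # Use per-patch normalized pixels as the regression target.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                 # per-patch MSE, shape (N, L)
        return (loss * mask).sum() / mask.sum()  # mean over masked patches only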
- forward(imgs)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
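In practice this means calling the module instance rather than forward() directly (the exact composition of the output is not documented here, so it is left opaque in this sketch):

    imgs = torch.rand(2, 3, 224, 224)
    output = model(imgs)            # preferred: runs any registered hooks
    # output = model.forward(imgs)  # also works, but silently skips the hooks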
- save_images(_x, _y, _mask, dtype)[source]
Save images from MAE.
- Parameters:
_x (Torch tensor) – Input images. In 2D: (N, C, H, W), in 3D: (N, C, Z, H, W). Where N is the batch size, C the number of channels, Z the image depth, H the image height and W the image width.
_y (Torch tensor) – Model predictions in patch form. In 2D: (N, L, patch_size**2 * C), in 3D: (N, L, patch_size**3 * C). Where N is the batch size, L the number of patches and C the number of channels.
_mask (2D array) – Indicates which patches are kept and which are masked. Shape is (N, L), where 0 means keep and 1 means remove.
dtype (Numpy dtype) – Dtype to save the images with.
- Returns:
pred (4D/5D Numpy array) – Predicted images converted to Numpy. In 2D: (N, H, W, C), in 3D: (N, Z, H, W, C). Where N is the batch size, C the number of channels, Z the image depth, H the image height and W the image width.
p_mask (4D/5D Numpy array) – Mask of the predicted images. Same shape convention as pred.
pred_visi (4D/5D Numpy array) – Predicted images with the visible patches kept. Same shape convention as pred.
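A usage sketch with illustrative tensors in the documented shapes (in practice _y and _mask would come from the model's forward pass; check the BiaPy source for how those outputs are packed):

    import numpy as np
    import torch

    x = torch.rand(2, 3, 224, 224)                 # (N, C, H, W)
    y = torch.rand(2, 196, 16 * 16 * 3)            # (N, L, patch_size**2 * C)
    mask = torch.randint(0, 2, (2, 196)).float()   # (N, L), 0 = keep, 1 = remove

    pred, p_mask, pred_visi = model.save_images(x, y, mask, np.float32)
    # Each returned array is (N, H, W, C) in 2D, e.g. (2, 224, 224, 3).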
- biapy.models.mae.mae_vit_base_patch16(**kwargs)
- biapy.models.mae.mae_vit_large_patch16(**kwargs)
- biapy.models.mae.mae_vit_huge_patch14(**kwargs)
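These helpers presumably fix the encoder size to the corresponding ViT preset (base/large with 16-pixel patches, huge with 14-pixel patches) and forward any extra keyword arguments to MaskedAutoencoderViT, so a call such as the following sketch should work (the accepted overrides depend on the preset defaults):

    from biapy.models.mae import mae_vit_base_patch16

    model = mae_vit_base_patch16(img_size=256, in_chans=1, ndim=2, mask_ratio=0.75)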