biapy.models.vit

This module implements a Vision Transformer (ViT) model that extends the timm library’s VisionTransformer with support for 2D and 3D inputs and an optional global average pooling of patch tokens.

The Vision Transformer processes images by dividing them into fixed-size patches, linearly embedding each patch, and then processing the resulting sequence of embeddings with a standard Transformer encoder. This module is often used as a backbone for various computer vision tasks, including classification and self-supervised learning (e.g., Masked Autoencoders).
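As a rough illustration of the patching step described above, the sketch below counts how many tokens a 2D or 3D input produces. It assumes non-overlapping square or cubic patches and input dimensions that are exact multiples of the patch size; `num_patches` is a hypothetical helper written for this example, not part of biapy or timm.

```python
import math

def num_patches(input_shape, patch_size):
    """Count the non-overlapping patches that tile an input.

    input_shape: spatial dimensions only, e.g. (H, W) or (D, H, W).
    patch_size: side length of each square or cubic patch.
    """
    for dim in input_shape:
        if dim % patch_size != 0:
            raise ValueError(
                f"dimension {dim} is not divisible by patch size {patch_size}"
            )
    return math.prod(dim // patch_size for dim in input_shape)

patches_2d = num_patches((224, 224), 16)     # 14 * 14 patches
patches_3d = num_patches((64, 64, 64), 16)   # 4 * 4 * 4 patches
seq_len_2d = patches_2d + 1                  # one extra position for the class token
```

The sequence length grows with the product of the spatial dimensions, so 3D inputs can produce long sequences even at modest resolutions.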

Classes:

  • VisionTransformer: An extended ViT model with support for 2D/3D inputs and global pooling.

Functions:

  • vit_base_patch16: Factory function for a base-sized ViT model with 16x16 patches.

  • vit_large_patch16: Factory function for a large-sized ViT model with 16x16 patches.

  • vit_huge_patch14: Factory function for a huge-sized ViT model with 14x14 patches.


class biapy.models.vit.VisionTransformer(ndim: int = 2, global_pool: bool = False, head_activations: List[str] = ['ce_sigmoid'], output_channel_info: List[str] = ['F'], explicit_activations: bool = False, **kwargs)

Bases: timm.models.vision_transformer.VisionTransformer

Vision Transformer (ViT) model with extensions for 2D/3D input and global pooling.

This class inherits from timm.models.vision_transformer.VisionTransformer and customizes it by replacing the default patch_embed with a custom implementation (biapy.models.tr_layers.PatchEmbed) that supports 2D and 3D inputs. It also adds an option for global average pooling of patch tokens for classification.

Reference: Masked Autoencoders Are Scalable Vision Learners.

Parameters:
  • ndim (int, optional) – Number of input dimensions (2 for 2D images, 3 for 3D images). Defaults to 2.

  • global_pool (bool, optional) – If True, applies global average pooling to the patch tokens (excluding the class token) before the final classification head. If False, uses the class token’s output for classification (standard ViT behavior). Defaults to False.

  • **kwargs – Arbitrary keyword arguments passed to the base timm.models.vision_transformer.VisionTransformer constructor, such as img_size, patch_size, in_chans, embed_dim, depth, num_heads, mlp_ratio, qkv_bias, norm_layer, etc.

forward_features(x)

Perform the forward pass through the Vision Transformer’s encoder.

This method processes the input image, converts it into patch embeddings, adds positional embeddings and the class token, and then passes the sequence through the transformer blocks. Finally, it applies either global pooling or extracts the class token’s output based on self.global_pool.
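The final selection step can be sketched in isolation. The example below is a simplified illustration, not the library's implementation: plain Python lists stand in for tensors, normalization layers are omitted, and the class token is assumed to sit at index 0 of the sequence. `mean_pool` and `select_features` are hypothetical names.

```python
def mean_pool(tokens):
    """Average a list of equal-length token vectors element-wise."""
    n = len(tokens)
    return [sum(values) / n for values in zip(*tokens)]

def select_features(sequence, global_pool):
    """Reduce one sample's token sequence to a single feature vector.

    sequence[0] is assumed to be the class token; the rest are patch tokens.
    """
    if global_pool:
        return mean_pool(sequence[1:])  # average the patch tokens only
    return sequence[0]                  # class token output

# Toy encoder output: 1 class token + 3 patch tokens, embed_dim = 2
sequence = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = select_features(sequence, global_pool=True)
cls_out = select_features(sequence, global_pool=False)
```

Both branches return a vector of length embed_dim, which is why the output shape does not depend on global_pool.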

Parameters:

x (torch.Tensor) – The input image tensor. Expected shape for 2D: (batch_size, channels, height, width). Expected shape for 3D: (batch_size, channels, depth, height, width).

Returns:

The output feature representation from the ViT encoder.

  • If global_pool is True: (batch_size, embed_dim) (pooled patch tokens).

  • If global_pool is False: (batch_size, embed_dim) (class token output).

Return type:

torch.Tensor

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.
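The difference between calling the instance and calling forward() can be shown with a toy class. This is a deliberately simplified stand-in, not torch.nn.Module: it only mimics the idea that `__call__` wraps `forward` and then runs forward hooks, and `ToyModule` is a hypothetical name.

```python
class ToyModule:
    """Toy stand-in for a module with forward hooks (not torch.nn.Module)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)  # hooks see the module, its input, and its output
        return out

calls = []
m = ToyModule()
m.register_forward_hook(lambda module, inp, out: calls.append(out))

via_call = m(3)             # runs forward, then the hook
via_forward = m.forward(3)  # runs forward only; the hook is skipped
```

Both calls return the same value, but only the first one is recorded by the hook.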

biapy.models.vit.vit_base_patch16(**kwargs)

Create a Vision Transformer (ViT) model with a Base-sized encoder and 16x16 patches.

This function serves as a convenient constructor for a specific ViT configuration, often used as a standard baseline.

Parameters:

**kwargs – Arbitrary keyword arguments to be passed to the VisionTransformer constructor. This allows overriding default parameters like img_size, in_chans, ndim, global_pool, etc.

Returns:

model – An initialized Base-sized ViT model.

Return type:

VisionTransformer
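The factory pattern described here, fixing the architecture arguments and forwarding everything else, can be sketched with a stand-in class. `ToyViTConfig` and `toy_vit_base_patch16` are hypothetical names that only record their arguments. The Base sizes used (embed_dim=768, depth=12, num_heads=12) are the values from the original ViT paper and are an assumption here, not taken from the BiaPy source.

```python
from dataclasses import dataclass

@dataclass
class ToyViTConfig:
    """Hypothetical stand-in for the model class; it only records its arguments."""
    patch_size: int
    embed_dim: int
    depth: int
    num_heads: int
    img_size: int = 224
    in_chans: int = 3
    ndim: int = 2
    global_pool: bool = False

def toy_vit_base_patch16(**kwargs):
    """Fix the architecture arguments and forward everything else.

    The sizes are assumed ViT-Base values, not read from the BiaPy source.
    """
    return ToyViTConfig(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)

# Override only the arguments the factory leaves open.
cfg = toy_vit_base_patch16(img_size=128, in_chans=1, ndim=3, global_pool=True)
```

In this sketch, passing an argument the factory already fixes (for example patch_size) raises a TypeError because the keyword is supplied twice. Whether the real factory behaves the same way depends on how it forwards its arguments.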

biapy.models.vit.vit_large_patch16(**kwargs)

Create a Vision Transformer (ViT) model with a Large-sized encoder and 16x16 patches.

This function provides a constructor for a larger ViT configuration, suitable for tasks requiring more capacity.

Parameters:

**kwargs – Arbitrary keyword arguments to be passed to the VisionTransformer constructor. This allows overriding default parameters like img_size, in_chans, ndim, global_pool, etc.

Returns:

model – An initialized Large-sized ViT model.

Return type:

VisionTransformer

biapy.models.vit.vit_huge_patch14(**kwargs)

Create a Vision Transformer (ViT) model with a Huge-sized encoder and 14x14 patches.

This function provides a constructor for the largest ViT configuration, designed for tasks demanding maximum model capacity.

Parameters:

**kwargs – Arbitrary keyword arguments to be passed to the VisionTransformer constructor. This allows overriding default parameters like img_size, in_chans, ndim, global_pool, etc.

Returns:

model – An initialized Huge-sized ViT model.

Return type:

VisionTransformer