Source code for biapy.models.vit

"""
This module implements a Vision Transformer (ViT) model, extending the `timm` library's `VisionTransformer` to support custom functionalities, particularly for different input dimensionalities (2D and 3D) and global pooling options.

The Vision Transformer processes images by dividing them into fixed-size patches,
linearly embedding each patch, and then processing the resulting sequence of
embeddings with a standard Transformer encoder. This module is often used as
a backbone for various computer vision tasks, including classification and
self-supervised learning (e.g., Masked Autoencoders).

Classes:

- ``VisionTransformer``: An extended ViT model with support for 2D/3D inputs and global pooling.

Functions:

- ``vit_base_patch16``: Factory function for a base-sized ViT model with 16x16 patches.
- ``vit_large_patch16``: Factory function for a large-sized ViT model with 16x16 patches.
- ``vit_huge_patch14``: Factory function for a huge-sized ViT model with 14x14 patches.

References:

- timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
- DeiT: https://github.com/facebookresearch/deit
- Masked Autoencoders Are Scalable Vision Learners: https://arxiv.org/abs/2111.06377
"""

from functools import partial
from typing import List, Optional

import torch
import torch.nn as nn
import timm.models.vision_transformer

from biapy.models.blocks import prepare_activation_layers
from biapy.models.tr_layers import PatchEmbed


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """
    Vision Transformer (ViT) model with extensions for 2D/3D input and global pooling.

    This class inherits from `timm.models.vision_transformer.VisionTransformer` and
    customizes it by replacing the default `patch_embed` with a custom implementation
    (`biapy.models.tr_layers.PatchEmbed`) that receives the number of input dimensions
    (``ndim``). It also adds an option for global average pooling of patch tokens and
    an optional explicit activation applied to the output of the parent ``forward``.

    Reference: `Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>`_.

    Parameters
    ----------
    ndim : int, optional
        Number of input dimensions (2 for 2D images, 3 for 3D images). Defaults to 2.

    global_pool : bool, optional
        If True, averages the patch tokens (excluding the class token) and normalizes
        the result with a dedicated ``fc_norm`` layer. If False, the class token taken
        after the final norm layer is used instead. Defaults to False.

    head_activations : list of str, optional
        Activation names forwarded to ``prepare_activation_layers``. Only used when
        ``explicit_activations`` is True. Defaults to ``["ce_sigmoid"]``.

    output_channel_info : list of str, optional
        Output channel descriptors forwarded to ``prepare_activation_layers``. Only
        used when ``explicit_activations`` is True. Defaults to ``["F"]``.

    explicit_activations : bool, optional
        If True, the first activation layer built by ``prepare_activation_layers`` is
        applied to the model output in ``forward``. Defaults to False.

    **kwargs
        Keyword arguments passed to the base `timm.models.vision_transformer.VisionTransformer`
        constructor (e.g. `depth`, `num_heads`, `mlp_ratio`, `qkv_bias`, `norm_layer`).
        ``img_size``, ``patch_size``, ``in_chans`` and ``embed_dim`` must be present,
        because they are also read here to build the custom patch embedding.
    """

    def __init__(
        self,
        ndim: int = 2,
        global_pool: bool = False,
        head_activations: Optional[List[str]] = None,
        output_channel_info: Optional[List[str]] = None,
        explicit_activations: bool = False,
        **kwargs
    ):
        """
        Initialize the VisionTransformer model.

        Calls the base `timm.models.vision_transformer.VisionTransformer` constructor,
        then replaces the `patch_embed` layer and the positional embedding, and sets up
        the global pooling configuration (removing the original normalization layer if
        global pooling is enabled).

        Parameters
        ----------
        ndim : int, optional
            Number of input dimensions (2 for 2D images, 3 for 3D images). Defaults to 2.

        global_pool : bool, optional
            If True, enables global average pooling of patch tokens. Defaults to False.

        head_activations : list of str, optional
            Activation names forwarded to ``prepare_activation_layers`` when
            ``explicit_activations`` is True. ``None`` (the default) means
            ``["ce_sigmoid"]``.

        output_channel_info : list of str, optional
            Output channel descriptors forwarded to ``prepare_activation_layers`` when
            ``explicit_activations`` is True. ``None`` (the default) means ``["F"]``.

        explicit_activations : bool, optional
            If True, enables explicit activation functions. Defaults to False.

        **kwargs
            Keyword arguments to pass to the parent `timm.models.vision_transformer.VisionTransformer`
            constructor. ``img_size``, ``patch_size``, ``in_chans`` and ``embed_dim``
            are required.

        Raises
        ------
        KeyError
            If ``img_size``, ``patch_size``, ``in_chans`` or ``embed_dim`` is missing
            from ``kwargs`` (``embed_dim`` is also needed when ``global_pool`` is True).
        """
        super(VisionTransformer, self).__init__(**kwargs)

        # Build the default lists per call: list literals in the signature would be
        # shared by every instance and could be mutated by downstream helpers.
        if head_activations is None:
            head_activations = ["ce_sigmoid"]
        if output_channel_info is None:
            output_channel_info = ["F"]

        self.ndim = ndim
        # NOTE(review): this overwrites the parent's attribute of the same name with a
        # bool; confirm the installed timm version's head handles that as intended.
        self.global_pool = global_pool
        self.explicit_activations = explicit_activations

        if self.explicit_activations:
            # Only the first element of the returned pair is kept.
            self.class_head_activations, _ = prepare_activation_layers(
                head_activations, output_channel_info, [self.num_classes]
            )

        if self.global_pool:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
            embed_dim = kwargs["embed_dim"]
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

        # Replace with our PatchEmbed implementation and re-define all dependant variables
        self.patch_embed = PatchEmbed(
            img_size=kwargs["img_size"],
            patch_size=kwargs["patch_size"],
            in_chans=kwargs["in_chans"],
            ndim=self.ndim,
            embed_dim=kwargs["embed_dim"],
            bias=True,
        )
        num_patches = self.patch_embed.num_patches
        # Without class-token embedding the positional table covers patches only.
        embed_len = num_patches if self.no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, kwargs["embed_dim"]) * 0.02)

    def forward_features(self, x):
        """
        Perform the forward pass through the Vision Transformer's encoder.

        The input is converted into patch embeddings, the class token and positional
        embeddings are added, and the sequence is passed through the transformer
        blocks. Finally, either the patch tokens are averaged and normalized with
        ``fc_norm`` or the class token is taken after ``norm``, depending on
        ``self.global_pool``.

        Parameters
        ----------
        x : torch.Tensor
            The input image tensor.
            Expected shape for 2D: `(batch_size, channels, height, width)`.
            Expected shape for 3D: `(batch_size, channels, depth, height, width)`.

        Returns
        -------
        torch.Tensor
            The output feature representation from the ViT encoder.
            - If `global_pool` is True: `(batch_size, embed_dim)` (pooled patch tokens).
            - If `global_pool` is False: `(batch_size, embed_dim)` (class token output).
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.no_embed_class:
            # pos_embed was sized for the patch tokens only (see __init__), so it must
            # be added before the class token is prepended.
            x = x + self.pos_embed
            x = torch.cat((cls_tokens, x), dim=1)
        else:
            # pos_embed has an entry for the class token: concatenate, then add.
            x = torch.cat((cls_tokens, x), dim=1)
            x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        """
        Run the parent model's ``forward`` and optionally apply the head activation.

        Parameters
        ----------
        x : torch.Tensor
            The input image tensor (see ``forward_features`` for the expected shapes).

        Returns
        -------
        torch.Tensor
            The output of the parent ``forward``; when ``explicit_activations`` is
            True, the first layer in ``class_head_activations`` is applied to it.
        """
        outs = super(VisionTransformer, self).forward(x)

        # Apply activations to the output heads if explicit_activations is True
        if self.explicit_activations:
            outs = self.class_head_activations[0](outs)

        return outs
def vit_base_patch16(**kwargs):
    """
    Build a Base-sized Vision Transformer (ViT) that uses a patch size of 16.

    The architecture preset is fixed here (``patch_size=16``, ``embed_dim=768``,
    ``depth=12``, ``num_heads=12``, ``mlp_ratio=4``, ``qkv_bias=True``); everything
    else is supplied by the caller.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments forwarded to the `VisionTransformer` constructor,
        such as `img_size`, `in_chans`, `ndim` or `global_pool`. Passing any of the
        preset keys listed above raises ``TypeError`` (duplicate keyword argument).

    Returns
    -------
    model : VisionTransformer
        An initialized Base-sized ViT model.
    """
    base_preset = {
        "patch_size": 16,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "mlp_ratio": 4,
        "qkv_bias": True,
    }
    # Double-splat keeps the duplicate-keyword check between preset and caller kwargs.
    return VisionTransformer(**base_preset, **kwargs)
def vit_large_patch16(**kwargs):
    """
    Build a Large-sized Vision Transformer (ViT) that uses a patch size of 16.

    The architecture preset is fixed here (``patch_size=16``, ``embed_dim=1024``,
    ``depth=24``, ``num_heads=16``, ``mlp_ratio=4``, ``qkv_bias=True``); everything
    else is supplied by the caller.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments forwarded to the `VisionTransformer` constructor,
        such as `img_size`, `in_chans`, `ndim` or `global_pool`. Passing any of the
        preset keys listed above raises ``TypeError`` (duplicate keyword argument).

    Returns
    -------
    model : VisionTransformer
        An initialized Large-sized ViT model.
    """
    large_preset = {
        "patch_size": 16,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "qkv_bias": True,
    }
    # Double-splat keeps the duplicate-keyword check between preset and caller kwargs.
    return VisionTransformer(**large_preset, **kwargs)
def vit_huge_patch14(**kwargs):
    """
    Build a Huge-sized Vision Transformer (ViT) that uses a patch size of 14.

    The architecture preset is fixed here (``patch_size=14``, ``embed_dim=1280``,
    ``depth=32``, ``num_heads=16``, ``mlp_ratio=4``, ``qkv_bias=True``); everything
    else is supplied by the caller.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments forwarded to the `VisionTransformer` constructor,
        such as `img_size`, `in_chans`, `ndim` or `global_pool`. Passing any of the
        preset keys listed above raises ``TypeError`` (duplicate keyword argument).

    Returns
    -------
    model : VisionTransformer
        An initialized Huge-sized ViT model.
    """
    huge_preset = {
        "patch_size": 14,
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
        "mlp_ratio": 4,
        "qkv_bias": True,
    }
    # Double-splat keeps the duplicate-keyword check between preset and caller kwargs.
    return VisionTransformer(**huge_preset, **kwargs)