biapy.models.tr_layers

This module provides the PatchEmbed class, a fundamental component used in Vision Transformers (ViT) to convert raw image data into sequences of flattened patches (tokens) suitable for transformer processing.

The PatchEmbed class projects image pixels into a higher-dimensional embedding space and can then flatten the resulting patch grid into a token sequence and normalize it. Both steps are optional. It supports both 2D and 3D image inputs.

Classes:

  • PatchEmbed: Transforms an input image into a sequence of embedded patches.

class biapy.models.tr_layers.PatchEmbed(img_size: int = 224, patch_size: int = 16, in_chans: int = 3, ndim: int = 2, embed_dim: int = 768, norm_layer: Callable | None = None, flatten: bool = True, bias: bool = True, strict_img_size: bool = True)

Bases: Module

Image to Patch Embedding module.

This module converts an input image into a sequence of non-overlapping patches and then projects these patches into a higher-dimensional embedding space. Optionally, it applies normalization to the embeddings. It is a core component of Vision Transformers (ViT).
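The following is a minimal sketch of a module with the same constructor parameters, written to illustrate the technique rather than to reproduce BiaPy's source. It assumes, as in common ViT implementations, that the projection is a convolution whose kernel size and stride both equal patch_size. The class name SimplePatchEmbed is hypothetical.

```python
from typing import Callable, Optional

import torch
import torch.nn as nn


class SimplePatchEmbed(nn.Module):
    """Hypothetical patch-embedding sketch (not BiaPy's implementation)."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        ndim: int = 2,
        embed_dim: int = 768,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
    ):
        super().__init__()
        if ndim not in (2, 3):
            raise ValueError("ndim must be 2 or 3")
        self.img_size = (img_size,) * ndim
        # One grid cell per non-overlapping patch along each spatial axis.
        self.grid_size = tuple(s // patch_size for s in self.img_size)
        self.num_patches = 1
        for g in self.grid_size:
            self.num_patches *= g
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        # Assumption: kernel_size == stride == patch_size, so each output
        # position is a linear projection of exactly one patch.
        conv = nn.Conv2d if ndim == 2 else nn.Conv3d
        self.proj = conv(in_chans, embed_dim, kernel_size=patch_size,
                         stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.strict_img_size:
            assert tuple(x.shape[2:]) == self.img_size, (
                f"Input size {tuple(x.shape[2:])} does not match img_size {self.img_size}"
            )
        x = self.proj(x)  # (B, embed_dim, *grid)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        # Caveat: with flatten=False the tensor is channels-first, so a
        # LayerNorm(embed_dim) here would normalize the last spatial axis instead.
        return self.norm(x)


embed = SimplePatchEmbed(img_size=224, patch_size=16, in_chans=3, ndim=2,
                         embed_dim=768, norm_layer=nn.LayerNorm)
tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768])
```

With the default 224 x 224 input and 16 x 16 patches, the grid is 14 x 14, which gives 196 tokens of size 768. Using a strided convolution is a convenient way to extract and project all patches in a single operation.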

forward(x)

Perform the forward pass of the PatchEmbed module.

Projects the input image into patch embeddings, optionally flattens them into a sequence, and applies normalization. If strict_img_size is True, it first asserts that the spatial dimensions of x match img_size.
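The projection and flattening steps can be checked numerically with toy sizes. This sketch assumes the projection is a strided convolution (kernel size = stride = patch size), which is an assumption about the implementation, not a statement of BiaPy's code. It shows that such a convolution gives the same result as cutting non-overlapping patches and applying one shared linear layer, and that flatten(2).transpose(1, 2) orders tokens row by row across the patch grid.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, H, W = 2, 3, 8, 8
p, E = 4, 5  # toy patch size and embedding dimension

conv = nn.Conv2d(C, E, kernel_size=p, stride=p)
x = torch.randn(B, C, H, W)

# Path 1: strided convolution, then flatten the grid into a token sequence.
tokens_conv = conv(x).flatten(2).transpose(1, 2)  # (B, N, E), N = (H//p) * (W//p)

# Path 2: cut non-overlapping p x p patches explicitly ...
patches = x.unfold(2, p, p).unfold(3, p, p)  # (B, C, H//p, W//p, p, p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)  # row-major patch order
# ... and project every patch with the same linear map, reusing the conv weights.
linear = nn.Linear(C * p * p, E)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(E, -1))  # (E, C, p, p) -> (E, C*p*p)
    linear.bias.copy_(conv.bias)
tokens_linear = linear(patches)

print(tokens_conv.shape)  # torch.Size([2, 4, 5])
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))  # True
```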

Parameters:

x (torch.Tensor) – The input image tensor.

  • For 2D: (batch_size, channels, height, width)

  • For 3D: (batch_size, channels, depth, height, width)

Returns:

The embedded patches.

  • If flatten is True: (batch_size, num_patches, embed_dim), where num_patches is the product of the grid sizes, i.e. (img_size // patch_size) ** ndim

  • If flatten is False: (batch_size, embed_dim, grid_size_H, grid_size_W) for 2D input, or (batch_size, embed_dim, grid_size_D, grid_size_H, grid_size_W) for 3D input, where each grid size is img_size // patch_size
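The output shapes above can be computed without building the module. This is a small sketch assuming a square (2D) or cubic (3D) input whose side equals img_size. The helper name patch_embed_output_shape is hypothetical and is not part of BiaPy.

```python
from math import prod
from typing import Tuple


def patch_embed_output_shape(batch_size: int, img_size: int, patch_size: int,
                             ndim: int, embed_dim: int,
                             flatten: bool = True) -> Tuple[int, ...]:
    """Hypothetical helper: expected PatchEmbed output shape."""
    grid = (img_size // patch_size,) * ndim  # patches per spatial axis
    if flatten:
        return (batch_size, prod(grid), embed_dim)  # (B, num_patches, embed_dim)
    return (batch_size, embed_dim, *grid)  # (B, embed_dim, [D,] H, W)


print(patch_embed_output_shape(2, 224, 16, ndim=2, embed_dim=768))  # (2, 196, 768)
print(patch_embed_output_shape(1, 64, 16, ndim=3, embed_dim=96, flatten=False))  # (1, 96, 4, 4, 4)
```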

Return type:

torch.Tensor

Raises:

AssertionError – If strict_img_size is True and the input image dimensions do not match the img_size specified during initialization.