biapy.models.rcan
This module implements the Deep Residual Channel Attention Networks (RCAN) model, a prominent architecture for image super-resolution.
RCAN leverages very deep residual networks combined with channel attention mechanisms to achieve high-quality image reconstruction. The model is built upon several key components:
Classes:
ChannelAttention: Implements a channel attention mechanism that recalibrates channel-wise feature responses by modeling interdependencies between channels.
RCAB (Residual Channel Attention Block): A fundamental building block that combines residual learning with the ChannelAttention mechanism.
RG (Residual Group): A collection of multiple RCABs, followed by a convolutional layer, with a global residual connection.
rcan: The main RCAN model, integrating the initial feature extraction, multiple Residual Groups, and an optional upscaling module for super-resolution.
The implementation supports both 2D and 3D image inputs and is adapted from the RCAN-pytorch repository.
Reference: Image Super-Resolution Using Very Deep Residual Channel Attention Networks.
Adapted from: https://github.com/yjn870/RCAN-pytorch
- class biapy.models.rcan.ChannelAttention(num_features, reduction, ndim=2)
Bases: Module
Implements a Channel Attention mechanism.
This module recalibrates channel-wise feature responses by adaptively learning the interdependencies between channels. It uses global average pooling to compute channel-wise statistics, followed by a small MLP (two 1x1 convolutions with SiLU and Sigmoid activations) to predict channel-wise scaling factors.
- forward(x)
Perform the forward pass of the ChannelAttention module.
Computes channel attention weights from the input x and then multiplies x element-wise by these weights, effectively recalibrating the features.
- Parameters:
x (torch.Tensor) – The input feature tensor. Expected shape for 2D: (batch_size, num_features, H, W). Expected shape for 3D: (batch_size, num_features, D, H, W).
- Returns:
The feature tensor after applying channel attention. Same shape as input x.
- Return type:
torch.Tensor
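The pooling, two-convolution MLP, and rescaling described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the BiaPy source: the class name ChannelAttentionSketch is hypothetical, and the layer order is inferred from the description (global average pooling, 1x1 convolution, SiLU, 1x1 convolution, Sigmoid).

```python
import torch
import torch.nn as nn


class ChannelAttentionSketch(nn.Module):
    """Hypothetical stand-in for illustration; not the BiaPy class."""

    def __init__(self, num_features: int, reduction: int, ndim: int = 2):
        super().__init__()
        # Pick 2D or 3D layers depending on ndim.
        conv = nn.Conv2d if ndim == 2 else nn.Conv3d
        pool = nn.AdaptiveAvgPool2d if ndim == 2 else nn.AdaptiveAvgPool3d
        self.weights = nn.Sequential(
            pool(1),  # global average pooling: one statistic per channel
            conv(num_features, num_features // reduction, kernel_size=1),
            nn.SiLU(),
            conv(num_features // reduction, num_features, kernel_size=1),
            nn.Sigmoid(),  # per-channel scaling factors in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The (B, C, 1, 1[, 1]) weights broadcast over the spatial dimensions.
        return x * self.weights(x)


x2d = torch.randn(2, 64, 16, 16)
out2d = ChannelAttentionSketch(64, reduction=16, ndim=2)(x2d)
x3d = torch.randn(2, 64, 4, 16, 16)
out3d = ChannelAttentionSketch(64, reduction=16, ndim=3)(x3d)
print(out2d.shape)  # torch.Size([2, 64, 16, 16])
print(out3d.shape)  # torch.Size([2, 64, 4, 16, 16])
```

Because the Sigmoid output lies strictly between 0 and 1, each channel is attenuated rather than amplified. In this sketch num_features must be at least reduction so that the bottleneck keeps one or more channels.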
- class biapy.models.rcan.RCAB(num_features, reduction, ndim=2)
Bases: Module
Residual Channel Attention Block (RCAB).
This block is a fundamental building unit of RCAN. It combines a residual connection with two convolutional layers and a ChannelAttention module to enhance feature learning and improve reconstruction quality.
- forward(x)
Perform the forward pass of the Residual Channel Attention Block.
The input x is processed through a sequence of convolutions and channel attention. The output of this sequence is then added back to the original input x via a residual connection.
- Parameters:
x (torch.Tensor) – The input feature tensor to the RCAB.
- Returns:
The output feature tensor after applying the RCAB. Same shape as input x.
- Return type:
torch.Tensor
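A rough 2D sketch of the block structure described above: two convolutions, channel attention, and a residual connection. The names RCABSketch and _ChannelScale are hypothetical. The 3x3 kernels, the padding of 1, and the SiLU between the convolutions are assumptions rather than details confirmed by this page.

```python
import torch
import torch.nn as nn


class _ChannelScale(nn.Module):
    """Minimal channel attention, included only to keep the sketch self-contained."""

    def __init__(self, num_features: int, reduction: int):
        super().__init__()
        self.weights = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_features, num_features // reduction, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(num_features // reduction, num_features, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weights(x)


class RCABSketch(nn.Module):
    """Hypothetical block: conv -> activation -> conv -> attention, plus skip."""

    def __init__(self, num_features: int, reduction: int):
        super().__init__()
        self.body = nn.Sequential(
            # Assumed 3x3 kernels with padding=1, which keep H and W unchanged.
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            _ChannelScale(num_features, reduction),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: the body only has to learn a correction to x.
        return x + self.body(x)


x = torch.randn(1, 32, 8, 8)
block = RCABSketch(num_features=32, reduction=8)
out = block(x)
print(out.shape)  # torch.Size([1, 32, 8, 8])
```

Since input and output shapes match, blocks of this kind can be chained any number of times without shape bookkeeping.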
- class biapy.models.rcan.RG(num_features, num_rcab, reduction, ndim=2)
Bases: Module
Residual Group (RG).
A Residual Group consists of multiple Residual Channel Attention Blocks (RCABs) followed by a convolutional layer. It incorporates a global residual connection around the entire group, allowing for the construction of very deep networks.
- forward(x)
Perform the forward pass of the Residual Group.
The input x is processed through the sequence of RCABs and the final convolutional layer. The output of this sequence is then added back to the original input x via a global residual connection.
- Parameters:
x (torch.Tensor) – The input feature tensor to the Residual Group.
- Returns:
The output feature tensor after processing through the Residual Group. Same shape as input x.
- Return type:
torch.Tensor
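The group-level skip connection can be illustrated with a short 2D sketch. ResidualGroupSketch and placeholder_block are hypothetical names, and the placeholder block (a convolution followed by SiLU) merely stands in for an RCAB to keep the example short. The 3x3 trailing convolution is likewise an assumption.

```python
import torch
import torch.nn as nn


class ResidualGroupSketch(nn.Module):
    """Hypothetical group: N blocks, a trailing conv, and a skip around all of it."""

    def __init__(self, num_features: int, num_blocks: int, make_block):
        super().__init__()
        layers = [make_block(num_features) for _ in range(num_blocks)]
        # Trailing convolution (assumed 3x3, padding=1 to preserve H and W).
        layers.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)  # skip connection around the whole group


def placeholder_block(num_features: int) -> nn.Module:
    """Stand-in for an RCAB; any shape-preserving module works here."""
    return nn.Sequential(
        nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
        nn.SiLU(),
    )


group = ResidualGroupSketch(num_features=16, num_blocks=3, make_block=placeholder_block)
x = torch.randn(1, 16, 8, 8)
out = group(x)
print(out.shape)  # torch.Size([1, 16, 8, 8])

# With the trailing convolution zeroed, the body contributes nothing
# and the group reduces to the identity mapping.
with torch.no_grad():
    group.body[-1].weight.zero_()
    group.body[-1].bias.zero_()
identity_out = group(x)
print(torch.allclose(identity_out, x))  # True
```

The zeroed-convolution check shows why the skip connection helps with depth: a group can fall back to passing its input through unchanged, so stacking more groups does not have to degrade the signal.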
- class biapy.models.rcan.rcan(ndim, num_channels=3, filters=64, scale=2, num_rg=10, num_rcab=20, reduction=16, upscaling_layer=True)
Bases: Module
Deep Residual Channel Attention Networks (RCAN) model.
RCAN is a very deep residual network designed for image super-resolution. It utilizes Residual Groups (RG) composed of Residual Channel Attention Blocks (RCABs) to learn hierarchical features and enhance reconstruction quality by focusing on informative channels. The model includes an initial feature extraction layer, multiple RGs, a global skip connection, and an optional upscaling layer.
Reference: Image Super-Resolution Using Very Deep Residual Channel Attention Networks.
Adapted from: https://github.com/yjn870/RCAN-pytorch
- forward(x) → Tensor
Perform the forward pass of the RCAN model.
The input x first undergoes shallow feature extraction. Then, it passes through a sequence of Residual Groups (RGs). A global residual connection adds the output of the RGs to the initial features. Optionally, an upscaling layer is applied, followed by a final convolutional layer to produce the super-resolved output.
- Parameters:
x (torch.Tensor) – The input image tensor. Expected shape for 2D: (batch_size, num_channels, H, W). Expected shape for 3D: (batch_size, num_channels, D, H, W).
- Returns:
The super-resolved output image tensor. If upscaling_layer is True, its spatial dimensions will be scaled by the scale factor. The number of channels will match num_channels.
- Return type:
torch.Tensor
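The data flow and the resulting output shapes can be sketched in 2D as below. RCANPipelineSketch is a hypothetical name and the body is a two-convolution placeholder rather than a real stack of Residual Groups. The upscaling step here uses a convolution followed by nn.PixelShuffle, which is an assumption made for illustration; the actual upscaling layer in BiaPy may be built differently, particularly for 3D inputs.

```python
import torch
import torch.nn as nn


class RCANPipelineSketch(nn.Module):
    """Hypothetical 2D outline: head -> body + global skip -> upscale -> tail."""

    def __init__(self, num_channels=3, filters=64, scale=2, upscaling_layer=True):
        super().__init__()
        # Shallow feature extraction.
        self.head = nn.Conv2d(num_channels, filters, kernel_size=3, padding=1)
        # Placeholder for the sequence of Residual Groups.
        self.body = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(filters, filters, kernel_size=3, padding=1),
        )
        if upscaling_layer:
            # Assumed sub-pixel upscaling: expand channels by scale**2,
            # then rearrange them into a (scale x scale) larger grid.
            self.upscale = nn.Sequential(
                nn.Conv2d(filters, filters * scale**2, kernel_size=3, padding=1),
                nn.PixelShuffle(scale),
            )
        else:
            self.upscale = nn.Identity()
        # Final convolution maps features back to image channels.
        self.tail = nn.Conv2d(filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shallow = self.head(x)
        deep = self.body(shallow) + shallow  # global residual connection
        return self.tail(self.upscale(deep))


x = torch.randn(1, 3, 24, 24)
sr = RCANPipelineSketch(scale=2)(x)
same = RCANPipelineSketch(upscaling_layer=False)(x)
print(sr.shape)    # torch.Size([1, 3, 48, 48])
print(same.shape)  # torch.Size([1, 3, 24, 24])
```

With the upscaling step enabled, H and W are each multiplied by scale. With it disabled, the output keeps the input resolution. In both cases the channel count returns to num_channels.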