Source code for biapy.models.rcan

"""
This module implements the Deep Residual Channel Attention Networks (RCAN) model, a prominent architecture for image super-resolution.

RCAN leverages very deep residual networks combined with channel attention
mechanisms to achieve high-quality image reconstruction. The model is built
upon several key components:

Classes:

- ``ChannelAttention``: Implements a channel attention mechanism that recalibrates
  channel-wise feature responses by modeling interdependencies between channels.
- ``RCAB`` (Residual Channel Attention Block): A fundamental building block that
  combines residual learning with the ChannelAttention mechanism.
- ``RG`` (Residual Group): A collection of multiple RCABs, followed by a
  convolutional layer, with a global residual connection.
- ``rcan``: The main RCAN model, integrating the initial feature extraction,
  multiple Residual Groups, and an optional upscaling module for super-resolution.

The implementation supports both 2D and 3D image inputs and is adapted from
the official RCAN-pytorch repository.

Reference:
`Image Super-Resolution Using Very Deep Residual Channel Attention Networks
<https://openaccess.thecvf.com/content_ECCV_2018/html/Yulun_Zhang_Image_Super-Resolution_Using_ECCV_2018_paper.html>`_.

Adapted from:
https://github.com/yjn870/RCAN-pytorch
"""
from typing import Sequence

import torch
from torch import nn

[docs] class ChannelAttention(nn.Module): """ Implements a Channel Attention mechanism. This module recalibrates channel-wise feature responses by adaptively learning the interdependencies between channels. It uses global average pooling to compute channel-wise statistics, followed by a small MLP (two 1x1 convolutions with SiLU and Sigmoid activations) to predict channel-wise scaling factors. """ def __init__(self, num_features, reduction, ndim=2): """ Initialize the ChannelAttention module. Sets up the adaptive pooling layer and the sequential convolutional layers that form the core of the channel attention mechanism. The choice between 2D and 3D layers depends on `ndim`. Parameters ---------- num_features : int The number of input and output channels for the attention module. reduction : int The reduction ratio for the intermediate channel dimension in the MLP-like structure. A higher reduction leads to a smaller model but might reduce expressive power. ndim : int, optional The number of spatial dimensions of the input data (2 for 2D, 3 for 3D). Defaults to 2. """ super(ChannelAttention, self).__init__() if ndim == 2: conv = nn.Conv2d avg_pool = nn.AdaptiveAvgPool2d else: conv = nn.Conv3d avg_pool = nn.AdaptiveAvgPool3d self.module = nn.Sequential( avg_pool(1), conv(num_features, num_features // reduction, kernel_size=1), nn.SiLU(inplace=True), conv(num_features // reduction, num_features, kernel_size=1), nn.Sigmoid(), )
[docs] def forward(self, x): """ Perform the forward pass of the ChannelAttention module. Computes channel attention weights from the input `x` and then multiplies `x` element-wise by these weights, effectively recalibrating the features. Parameters ---------- x : torch.Tensor The input feature tensor. Expected shape for 2D: `(batch_size, num_features, H, W)`. Expected shape for 3D: `(batch_size, num_features, D, H, W)`. Returns ------- torch.Tensor The feature tensor after applying channel attention. Same shape as input `x`. """ return x * self.module(x)
[docs] class RCAB(nn.Module): """ Residual Channel Attention Block (RCAB). This block is a fundamental building unit of RCAN. It combines a residual connection with two convolutional layers and a ChannelAttention module to enhance feature learning and improve reconstruction quality. """ def __init__(self, num_features, reduction, ndim=2): """ Initialize the Residual Channel Attention Block. Sets up two convolutional layers and integrates a `ChannelAttention` module within a residual learning framework. The choice of 2D or 3D convolution depends on `ndim`. Parameters ---------- num_features : int The number of input and output channels for the convolutional layers and the `ChannelAttention` module within the block. reduction : int The reduction ratio passed to the `ChannelAttention` module. ndim : int, optional The number of spatial dimensions of the input data (2 for 2D, 3 for 3D). Defaults to 2. """ super(RCAB, self).__init__() if ndim == 2: conv = nn.Conv2d else: conv = nn.Conv3d self.module = nn.Sequential( conv(num_features, num_features, kernel_size=3, padding="same"), nn.SiLU(inplace=True), conv(num_features, num_features, kernel_size=3, padding="same"), ChannelAttention(num_features, reduction, ndim=ndim), )
[docs] def forward(self, x): """ Perform the forward pass of the Residual Channel Attention Block. The input `x` is processed through a sequence of convolutions and channel attention. The output of this sequence is then added back to the original input `x` via a residual connection. Parameters ---------- x : torch.Tensor The input feature tensor to the RCAB. Returns ------- torch.Tensor The output feature tensor after applying the RCAB. Same shape as input `x`. """ return x + self.module(x)
[docs] class RG(nn.Module): """ Residual Group (RG). A Residual Group consists of multiple Residual Channel Attention Blocks (RCABs) followed by a convolutional layer. It incorporates a global residual connection around the entire group, allowing for the construction of very deep networks. """ def __init__(self, num_features, num_rcab, reduction, ndim=2): """ Initialize a Residual Group. Constructs a sequence of `num_rcab` RCABs and appends a final convolutional layer. The entire sequence is wrapped in `nn.Sequential`. Parameters ---------- num_features : int The number of features (channels) processed throughout the group. num_rcab : int The number of `RCAB` blocks to include in this Residual Group. reduction : int The reduction ratio passed to each `RCAB`'s `ChannelAttention` module. ndim : int, optional The number of spatial dimensions (2 for 2D, 3 for 3D). Defaults to 2. """ super(RG, self).__init__() if ndim == 2: conv = nn.Conv2d else: conv = nn.Conv3d self.module = [RCAB(num_features, reduction, ndim=ndim) for _ in range(num_rcab)] self.module.append(conv(num_features, num_features, kernel_size=3, padding="same")) self.module = nn.Sequential(*self.module)
[docs] def forward(self, x): """ Perform the forward pass of the Residual Group. The input `x` is processed through the sequence of RCABs and the final convolutional layer. The output of this sequence is then added back to the original input `x` via a global residual connection. Parameters ---------- x : torch.Tensor The input feature tensor to the Residual Group. Returns ------- torch.Tensor The output feature tensor after processing through the Residual Group. Same shape as input `x`. """ return x + self.module(x)
[docs] class rcan(nn.Module): """ Deep Residual Channel Attention Networks (RCAN) model. RCAN is a very deep residual network designed for image super-resolution. It utilizes Residual Groups (RG) composed of Residual Channel Attention Blocks (RCABs) to learn hierarchical features and enhance reconstruction quality by focusing on informative channels. The model includes an initial feature extraction layer, multiple RGs, a global skip connection, and an optional upscaling layer. Reference: `Image Super-Resolution Using Very Deep Residual Channel Attention Networks <https://openaccess.thecvf.com/content_ECCV_2018/html/Yulun_Zhang_Image_Super-Resolution_Using_ECCV_2018_paper.html>`_. Adapted from `here <https://github.com/yjn870/RCAN-pytorch>`_. """ def __init__( self, ndim, num_channels=3, filters=64, scale=2, num_rg=10, num_rcab=20, reduction=16, upscaling_layer=True, ): """ Initialize the RCAN model. Sets up the initial shallow feature extraction layer, a sequence of Residual Groups (RGs), a global convolutional layer, an optional upscaling module (using PixelShuffle), and the final reconstruction layer. The choice between 2D and 3D layers depends on `ndim`. Parameters ---------- ndim : int The number of spatial dimensions of the input data (2 for 2D, 3 for 3D). num_channels : int, optional The number of input and output image channels (e.g., 3 for RGB). Defaults to 3. filters : int, optional The number of feature maps (channels) used throughout the main body of the network (e.g., within RGs and RCABs). Defaults to 64. scale : int | Tuple[int, ...], optional The super-resolution upscaling factor. If a tuple, only the first element is used as `PixelShuffle` expects a single integer factor. Defaults to 2. num_rg : int, optional The number of Residual Groups (RGs) to stack in the network. Defaults to 10. num_rcab : int, optional The number of RCABs within each Residual Group. Defaults to 20. reduction : int, optional The reduction ratio for the `ChannelAttention` module within each RCAB. Defaults to 16. upscaling_layer : bool, optional If True, an upscaling layer (using PixelShuffle) is included before the final convolutional layer to perform super-resolution. If False, the model outputs at the same resolution as the input features. Defaults to True. """ super(rcan, self).__init__() if type(scale) is not int and isinstance(scale, Sequence): scale = scale[0] self.ndim = ndim self.upscaling_layer = upscaling_layer if ndim == 2: conv = nn.Conv2d else: conv = nn.Conv3d # Shallow Feature Extraction (SF) self.sf = conv(num_channels, filters, kernel_size=3, padding="same") # Residual Groups (RGs) self.rgs = nn.Sequential(*[RG(filters, num_rcab, reduction, ndim=ndim) for _ in range(num_rg)]) # Global Skip Connection Convolution self.conv1 = conv(filters, filters, kernel_size=3, padding="same") # Optional Upscaling Layer if upscaling_layer: self.upscale = nn.Sequential( conv(filters, filters * (scale**2), kernel_size=3, padding="same"), nn.PixelShuffle(scale), ) # Final Reconstruction Layer self.conv2 = conv(filters, num_channels, kernel_size=3, padding="same")
[docs] def forward(self, x) -> torch.Tensor: """ Perform the forward pass of the RCAN model. The input `x` first undergoes shallow feature extraction. Then, it passes through a sequence of Residual Groups (RGs). A global residual connection adds the output of the RGs to the initial features. Optionally, an upscaling layer is applied, followed by a final convolutional layer to produce the super-resolved output. Parameters ---------- x : torch.Tensor The input image tensor. Expected shape for 2D: `(batch_size, num_channels, H, W)`. Expected shape for 3D: `(batch_size, num_channels, D, H, W)`. Returns ------- torch.Tensor The super-resolved output image tensor. If `upscaling_layer` is True, its spatial dimensions will be scaled by the `scale` factor. The number of channels will match `num_channels`. """ x = self.sf(x) residual = x x = self.rgs(x) x = self.conv1(x) x += residual if self.upscaling_layer: x = self.upscale(x) x = self.conv2(x) return x