biapy.models.hrnet

This file implements the High-Resolution Net (HRNet) model and its core building blocks, designed for dense prediction tasks in 2D and 3D imaging.

The HRNet architecture maintains high-resolution representations throughout the network by connecting high-to-low resolution convolution streams in parallel and facilitating repeated information exchange across these streams.
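To make the idea of parallel high-to-low resolution streams concrete, the sketch below computes the spatial shape each stream would see. It is only an illustration: `branch_shapes`, `stem_stride` and `branch_factor` are hypothetical names, and the values (a stem that downsamples by 4, each further branch halving the resolution) follow the original HRNet design rather than anything confirmed for this module, where `branch_strides` may change them.

```python
from typing import List, Tuple


def branch_shapes(
    spatial_shape: Tuple[int, ...],
    num_branches: int,
    stem_stride: int = 4,    # assumption: the stem downsamples by 4 (original HRNet)
    branch_factor: int = 2,  # assumption: each extra branch halves the resolution
) -> List[Tuple[int, ...]]:
    """Return the spatial shape of every parallel stream, highest resolution first."""
    shapes = []
    for branch in range(num_branches):
        stride = stem_stride * branch_factor ** branch
        shapes.append(tuple(size // stride for size in spatial_shape))
    return shapes


print(branch_shapes((256, 256), num_branches=4))
# [(64, 64), (32, 32), (16, 16), (8, 8)]
```

The same helper works for 3D inputs by passing a three-element spatial shape.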

Key components:

  • HighResolutionNet: The main HRNet model.

  • HighResolutionModule: Core HRNet building block that manages multi-resolution fusion.

  • HRBasicBlock: Basic residual block for HRNet.

  • HRBottleneck: Bottleneck residual block for HRNet.

Reference: Deep high-resolution representation learning for visual recognition

Code adapted from: Exploring Cross-Image Pixel Contrast for Semantic Segmentation

class biapy.models.hrnet.HighResolutionModule(ndim: int, num_branches: int, blocks: Type[HRBasicBlock | HRBottleneck], num_blocks: List[int], num_inchannels: List[int], num_channels: List[int], multi_scale_output: bool = True, norm: str = 'none', branch_strides: List[Tuple[int, ...]] | None = None, activation: str = 'relu')

Bases: Module

get_num_inchannels()

Retrieve the current number of input channels for each branch.

This method provides access to the dynamically updated num_inchannels list, which reflects the channel counts of features after they have passed through the respective blocks within this module. This is useful for configuring subsequent stages or modules.

Returns:

A list where each element represents the number of channels for the corresponding branch’s output.

Return type:

List[int]
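As a rough illustration of how the channel counts might be updated, the sketch below multiplies each branch's configured width by a block expansion factor. The helper `updated_inchannels` is hypothetical, and the expansion values (1 for a basic block, 4 for a bottleneck) are taken from the original HRNet code; whether HRBasicBlock and HRBottleneck use the same factors is an assumption.

```python
from typing import List


def updated_inchannels(num_channels: List[int], expansion: int) -> List[int]:
    """Channel count each branch emits after passing through its blocks.

    `expansion` is an assumed per-block channel multiplier
    (1 for a basic block, 4 for a bottleneck in the original HRNet code).
    """
    return [channels * expansion for channels in num_channels]


print(updated_inchannels([32, 64, 128], expansion=1))  # [32, 64, 128]
print(updated_inchannels([32, 64, 128], expansion=4))  # [128, 256, 512]
```

A following stage would then read these values to size its own input layers.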

forward(x)

Perform the forward pass of the High Resolution Module.

The input is a list of tensors, where each tensor corresponds to a feature map from a parallel resolution branch. Each feature map first passes through its respective branch’s convolutional blocks. Then, the outputs from all branches are fused by upsampling or downsampling as necessary, followed by element-wise summation to create new feature maps at each target resolution.

Parameters:

x (List[torch.Tensor]) – A list of input feature tensors, where each tensor corresponds to a different resolution branch. The order typically goes from highest to lowest resolution.

Returns:

A list of fused output feature tensors, one per output resolution. If multi_scale_output is True, the list contains features for all output resolutions; otherwise, it may contain only the highest-resolution output.

Return type:

List[torch.Tensor]
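The fusion step described above can be sketched with a toy example. This is not the module's implementation: plain Python lists stand in for single-channel 1D feature maps, nearest-neighbour repetition and window averaging stand in for whatever learned upsampling and downsampling layers the module uses, and all function names are hypothetical. The behaviour for `multi_scale_output=False` (keeping only the highest-resolution output) is likewise an assumption.

```python
from typing import List

Signal = List[float]  # stand-in for one branch's feature map (1D, single channel)


def upsample(signal: Signal, factor: int) -> Signal:
    """Nearest-neighbour upsampling: repeat every value `factor` times."""
    return [value for value in signal for _ in range(factor)]


def downsample(signal: Signal, factor: int) -> Signal:
    """Average non-overlapping windows of length `factor`."""
    return [
        sum(signal[i:i + factor]) / factor
        for i in range(0, len(signal), factor)
    ]


def resize(signal: Signal, target_len: int) -> Signal:
    """Bring a branch to the target length (lengths assumed to divide evenly)."""
    if len(signal) < target_len:
        return upsample(signal, target_len // len(signal))
    if len(signal) > target_len:
        return downsample(signal, len(signal) // target_len)
    return signal


def fuse(branches: List[Signal], multi_scale_output: bool = True) -> List[Signal]:
    """Sum all branches at each target resolution, highest resolution first."""
    targets = branches if multi_scale_output else branches[:1]
    fused = []
    for target in targets:
        resized = [resize(branch, len(target)) for branch in branches]
        fused.append([sum(values) for values in zip(*resized)])
    return fused


high = [1.0, 2.0, 3.0, 4.0]  # highest-resolution branch
low = [10.0, 20.0]           # half-resolution branch
print(fuse([high, low]))
# [[11.0, 12.0, 23.0, 24.0], [11.5, 23.5]]
```

Each output keeps the resolution of its own branch while receiving a contribution from every other branch, which is the repeated information exchange the architecture relies on.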

class biapy.models.hrnet.HighResolutionNet(cfg: Dict, image_shape: Tuple[int, ...] = (256, 256, 1), normalization: str = 'none', output_channels: List[int] = [1], output_channel_info=['F'], explicit_activations: bool = False, head_activations: List[str] = ['ce_sigmoid'], contrast: bool = False, contrast_proj_dim: int = 256, head_type: str = 'FCN', activation: str = 'relu', return_one_tensor: bool = False)

Bases: Module

forward(input) → Dict | Tensor

Perform the forward pass of the HighResolutionNet.

The input first goes through the initial convolutional blocks. It then propagates through a series of HRNet stages, where feature maps are processed in parallel across multiple resolutions and information is exchanged between them. Finally, features from all resolutions are fused and passed through a final prediction head. Optionally, a contrastive learning projection head and/or a multi-head classification output can be included.

Parameters:

input (torch.Tensor) – The input image tensor. Expected shape for 2D: (batch_size, channels, height, width). Expected shape for 3D: (batch_size, channels, depth, height, width).

Returns:

If there is only one output head, returns a tensor with the predictions. If there are multiple output heads (e.g. for multi-task learning), returns a dictionary with keys:

  • "pred": tensor with the main predictions (e.g. segmentation map)

  • "class": tensor with the classification output (if return_class is True)

  • "embed": tensor with the contrastive learning embedding (if contrast is True)

Return type:

Dict or torch.Tensor
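Because the return value can be either a bare tensor or a dictionary, calling code may want to normalise it before use. The helper below is a minimal sketch of one way to do that; `main_prediction` is a hypothetical name, not part of biapy, and plain Python objects stand in for tensors so the example runs without PyTorch.

```python
from typing import Any


def main_prediction(output: Any) -> Any:
    """Return the main prediction whether the model gave a dict or a bare tensor."""
    if isinstance(output, dict):
        return output["pred"]
    return output


# Plain Python lists stand in for tensors here.
single_head = [0.1, 0.9]
multi_head = {"pred": [0.1, 0.9], "class": [1], "embed": [0.3, 0.7]}

print(main_prediction(single_head))  # [0.1, 0.9]
print(main_prediction(multi_head))   # [0.1, 0.9]
```

The "class" and "embed" entries can be read the same way, after checking that the key is present, since they are only included when the corresponding option is enabled.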