"""
This file implements the High-Resolution Net (HRNet) model and its core building blocks,
designed for dense prediction tasks in 2D and 3D imaging.
The HRNet architecture maintains high-resolution representations throughout the
network by connecting high-to-low resolution convolution streams in parallel
and facilitating repeated information exchange across these streams.
Key components:
- ``HighResolutionNet``: The main HRNet model.
- ``HighResolutionModule``: Core HRNet building block that manages multi-resolution fusion.
- ``HRBasicBlock``: Basic residual block for HRNet.
- ``HRBottleneck``: Bottleneck residual block for HRNet.
Reference:
`Deep high-resolution representation learning for visual recognition <https://ieeexplore.ieee.org/abstract/document/9052469/>`_
Code adapted from:
`Exploring Cross-Image Pixel Contrast for Semantic Segmentation <https://github.com/tfzhou/ContrastiveSeg/tree/main>`_
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Type
from biapy.models.blocks import (
HRBasicBlock,
HRBottleneck,
ConvBlock,
get_norm_3d,
get_norm_2d,
ConvNeXtBlock_V2,
ConvNeXtBlock_V1,
prepare_activation_layers,
get_activation,
)
from biapy.models.heads import ASPP, ProjectionHead, PSP, OCRHead
class HighResolutionModule(nn.Module):
    """
    Multi-resolution building block of HRNet.

    Each input tensor is processed by its own stack of residual blocks (a
    "branch"). The branch outputs are then fused: every output branch is the
    activated sum of all branch outputs, each brought to the target branch's
    channel count and spatial size by a convolution (plus interpolation when
    coming from a lower resolution).
    """

    def __init__(
        self,
        ndim: int,
        num_branches: int,
        blocks: "Type[HRBasicBlock | HRBottleneck]",
        num_blocks: List[int],
        num_inchannels: List[int],
        num_channels: List[int],
        multi_scale_output: bool = True,
        norm: str = "none",
        branch_strides: "List[Tuple[int, ...]] | None" = None,
        activation: str = "relu",
    ):
        """
        Initialize a High Resolution Module.

        Builds the parallel branches and the fusion layers that exchange
        features between them, choosing 2D or 3D convolution/normalization
        factories from ``ndim``.

        Parameters
        ----------
        ndim : int
            Number of spatial dimensions (3 selects the 3D layers, any other
            value selects the 2D layers).
        num_branches : int
            Number of parallel resolution branches.
        blocks : Type[HRBasicBlock | HRBottleneck]
            Residual block class used inside the branches. It must expose an
            ``expansion`` class attribute.
        num_blocks : List[int]
            Number of residual blocks per branch (length ``num_branches``).
        num_inchannels : List[int]
            Input channel count per branch (length ``num_branches``). The list
            is copied; the caller's list is never modified.
        num_channels : List[int]
            Pre-expansion channel width per branch (length ``num_branches``).
            Each branch outputs ``num_channels[i] * blocks.expansion`` channels.
        multi_scale_output : bool, optional
            If True, fusion produces one output per branch; otherwise only the
            first (index 0) branch output is produced. Defaults to True.
        norm : str, optional
            Normalization identifier forwarded to the blocks and fusion layers
            (e.g. 'bn', 'sync_bn', 'in', 'gn', 'none'). Defaults to "none".
        branch_strides : List[Tuple[int, ...]], optional
            Absolute stride of each branch, one tuple per branch with one entry
            per spatial dimension. Used to derive how much a higher-resolution
            branch must be downsampled to reach a lower-resolution one. When
            None, branch ``i`` is assumed to have stride ``2**i`` on every axis.
        activation : str, optional
            Activation identifier used in the fusion layers and applied to
            every fused output. Defaults to "relu".

        Raises
        ------
        ValueError
            If the lengths of ``num_blocks``, ``num_inchannels`` or
            ``num_channels`` do not match ``num_branches``, or if
            ``branch_strides`` has fewer entries than ``num_branches``.
        """
        super().__init__()
        self.ndim = ndim
        if self.ndim == 3:
            self.conv_call = nn.Conv3d
            self.norm_func = get_norm_3d
        else:
            self.conv_call = nn.Conv2d
            self.norm_func = get_norm_2d

        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

        # Copy: `_make_one_branch` updates this list in place, and the caller may
        # have passed the very same list object as `num_channels`.
        self.num_inchannels = list(num_inchannels)
        self.num_branches = num_branches
        self.multi_scale_output = multi_scale_output

        if branch_strides is None:
            branch_strides = self._default_branch_strides(3 if ndim == 3 else 2, num_branches)
        elif num_branches > 1 and len(branch_strides) < num_branches:
            raise ValueError(
                "NUM_BRANCHES({}) > BRANCH_STRIDES({})".format(num_branches, len(branch_strides))
            )
        self.branch_strides = [tuple(s) for s in branch_strides]

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels, norm=norm)
        self.activation_str = activation
        self.activation = get_activation(activation)
        self.fuse_layers = self._make_fuse_layers(norm=norm)

    @staticmethod
    def _default_branch_strides(num_dims: int, num_branches: int) -> List[Tuple[int, ...]]:
        """Return the classic HRNet strides: branch ``i`` has stride ``2**i`` on every axis."""
        return [tuple(2**b for _ in range(num_dims)) for b in range(num_branches)]

    @staticmethod
    def _fusion_step_strides(rel_stride: Tuple[int, ...]) -> List[Tuple[int, ...]]:
        """
        Split a relative downsampling factor into a sequence of per-conv strides.

        Every step halves (stride 2) each axis that still has a remaining factor
        above 1 and leaves the other axes untouched (stride 1). The number of
        steps is the number of times the largest factor can be integer-halved
        before reaching 1. A factor of all ones yields a single all-ones step.

        Parameters
        ----------
        rel_stride : Tuple[int, ...]
            Per-axis downsampling factor between two branches.

        Returns
        -------
        List[Tuple[int, ...]]
            One stride tuple per convolution, in application order.
        """
        if all(s == 1 for s in rel_stride):
            return [tuple(1 for _ in rel_stride)]

        num_steps = 0
        factor = max(rel_stride)
        while factor > 1:
            num_steps += 1
            factor //= 2

        remaining = list(rel_stride)
        steps = []
        for _ in range(num_steps):
            step = []
            for axis, left in enumerate(remaining):
                if left > 1:
                    step.append(2)
                    remaining[axis] = left // 2
                else:
                    step.append(1)
            steps.append(tuple(step))
        return steps

    def _check_branches(
        self, num_branches: int, num_blocks: List[int], num_inchannels: List[int], num_channels: List[int]
    ):
        """
        Validate that every per-branch list has exactly ``num_branches`` entries.

        Parameters
        ----------
        num_branches : int
            Number of branches in the module.
        num_blocks : List[int]
            Number of blocks in each branch.
        num_inchannels : List[int]
            Number of input channels for each branch.
        num_channels : List[int]
            Number of output channels for each branch.

        Raises
        ------
        ValueError
            If any list length differs from ``num_branches``.
        """
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
            raise ValueError(error_msg)
        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(num_branches, len(num_channels))
            raise ValueError(error_msg)
        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(
        self,
        branch_index: int,
        block: "Type[HRBasicBlock | HRBottleneck]",
        num_blocks: List[int],
        num_channels: List[int],
        stride: int = 1,
        norm: str = "none",
    ):
        """
        Create the residual-block stack of one branch.

        Side effect: ``self.num_inchannels[branch_index]`` is updated to the
        branch's output channel count (``num_channels * block.expansion``).

        Parameters
        ----------
        branch_index : int
            Index of the branch to create.
        block : Type[HRBasicBlock | HRBottleneck]
            Residual block class to instantiate.
        num_blocks : List[int]
            Number of blocks per branch; only entry ``branch_index`` is used.
        num_channels : List[int]
            Pre-expansion channel width per branch; only entry ``branch_index``
            is used.
        stride : int, optional
            Stride of the first block (and of its shortcut projection).
            Default is 1.
        norm : str, optional
            Normalization layer to use (one of 'bn', 'sync_bn', 'in', 'gn', or 'none'). Default is 'none'.

        Returns
        -------
        nn.Sequential
            The blocks of the branch, in order.
        """
        width = num_channels[branch_index]
        expanded = width * block.expansion

        # A 1x1 projection on the shortcut is only needed when the first block
        # changes the resolution or the channel count.
        downsample = None
        if stride != 1 or self.num_inchannels[branch_index] != expanded:
            downsample = nn.Sequential(
                self.conv_call(
                    self.num_inchannels[branch_index],
                    expanded,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                self.norm_func(norm, expanded),
            )

        # `out_size` is the pre-expansion width for every block, the first one
        # included, so that the block output matches the `expanded` shortcut.
        layers = [
            block(
                conv=self.conv_call,
                in_size=self.num_inchannels[branch_index],
                out_size=width,
                stride=stride,
                norm=norm,
                downsample=downsample,
            )
        ]
        self.num_inchannels[branch_index] = expanded
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    conv=self.conv_call,
                    in_size=self.num_inchannels[branch_index],
                    out_size=width,
                    norm=norm,
                )
            )
        return nn.Sequential(*layers)

    def _make_branches(
        self,
        num_branches: int,
        block: "Type[HRBasicBlock | HRBottleneck]",
        num_blocks: List[int],
        num_channels: List[int],
        norm: str,
    ):
        """
        Create all branches of the module.

        Parameters
        ----------
        num_branches : int
            Number of branches to create.
        block : Type[HRBasicBlock | HRBottleneck]
            Residual block class used in the branches.
        num_blocks : List[int]
            Number of blocks in each branch.
        num_channels : List[int]
            Pre-expansion channel width of each branch.
        norm : str
            Normalization layer to use (one of 'bn', 'sync_bn', 'in', 'gn', or 'none').

        Returns
        -------
        branches : nn.ModuleList
            One ``nn.Sequential`` per branch.
        """
        return nn.ModuleList(
            [self._make_one_branch(i, block, num_blocks, num_channels, norm=norm) for i in range(num_branches)]
        )

    def _make_fuse_layers(self, norm: str):
        """
        Construct the layers that map every branch output onto every target branch.

        For target branch ``i`` and source branch ``j``:

        * ``j > i`` (source is lower resolution): a 1x1 ``ConvBlock`` adapting the
          channels; spatial upsampling is done later in ``forward``.
        * ``j == i``: ``None`` (identity).
        * ``j < i`` (source is higher resolution): a chain of 3x3 ``ConvBlock``
          layers whose strides come from ``_fusion_step_strides``; only the last
          one changes the channel count and it has no activation.

        Parameters
        ----------
        norm : str
            Normalization layer type to use within the fusion layers.

        Returns
        -------
        fuse_layers : nn.ModuleList or None
            ``fuse_layers[i][j]`` as described above, or ``None`` when the
            module has a single branch. Only target 0 is built when
            ``multi_scale_output`` is False.
        """
        if self.num_branches == 1:
            return None

        channels = self.num_inchannels
        num_targets = self.num_branches if self.multi_scale_output else 1

        fuse_layers = []
        for i in range(num_targets):
            row = []
            for j in range(self.num_branches):
                if j > i:
                    row.append(
                        ConvBlock(
                            conv=self.conv_call,
                            in_size=channels[j],
                            out_size=channels[i],
                            k_size=1,
                            padding=0,
                            stride=1,
                            norm=norm,
                            bias=False,
                        )
                    )
                elif j == i:
                    row.append(None)
                else:
                    # Downsampling factor from branch j to branch i, per axis.
                    rel_stride = tuple(
                        si // sj for si, sj in zip(self.branch_strides[i], self.branch_strides[j])
                    )
                    step_strides = self._fusion_step_strides(rel_stride)
                    last = len(step_strides) - 1

                    convs = []
                    in_ch = channels[j]
                    for k, step in enumerate(step_strides):
                        is_last = k == last
                        out_ch = channels[i] if is_last else in_ch
                        convs.append(
                            ConvBlock(
                                conv=self.conv_call,
                                in_size=in_ch,
                                out_size=out_ch,
                                k_size=3,
                                padding=1,
                                stride=step,
                                # No activation right before the fusion sum.
                                act="none" if is_last else self.activation_str,
                                norm=norm,
                                bias=False,
                            )
                        )
                        in_ch = out_ch
                    row.append(nn.Sequential(*convs))
            fuse_layers.append(nn.ModuleList(row))
        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """
        Return the per-branch channel counts after the branch blocks.

        The list is updated while the branches are built, so after construction
        it holds the output channel count of every branch. It is what a
        following module or stage must use as its input channel counts.

        Returns
        -------
        List[int]
            Output channel count of each branch.
        """
        return self.num_inchannels

    def forward(self, x):
        """
        Run every branch and fuse the results.

        Parameters
        ----------
        x : List[torch.Tensor]
            One feature tensor per branch, indexed like the branches. The list
            itself is not modified.

        Returns
        -------
        List[torch.Tensor]
            Fused features. One tensor per branch when ``multi_scale_output``
            is True, otherwise a single tensor for branch 0. With a single
            branch, the un-fused branch output is returned in a one-element
            list.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        feats = [self.branches[i](x[i]) for i in range(self.num_branches)]

        fused = []
        for i, row in enumerate(self.fuse_layers):
            target = feats[i]
            if target.ndim == 4:
                size, mode = list(target.shape[-2:]), "bilinear"
            else:
                size, mode = list(target.shape[-3:]), "trilinear"

            y = None
            for j in range(self.num_branches):
                if j == i:
                    term = feats[j]
                elif j > i:
                    # Lower-resolution source: adapt channels, then resize to the target.
                    term = F.interpolate(row[j](feats[j]), size=size, mode=mode, align_corners=True)
                else:
                    # Higher-resolution source: strided convolutions do the resizing.
                    term = row[j](feats[j])
                y = term if y is None else y + term
            fused.append(self.activation(y))
        return fused
class HighResolutionNet(nn.Module):
    """
    2D/3D High-Resolution Net (HRNet) for dense prediction.

    A two-convolution stem and a bottleneck layer feed a configurable number of
    stages. Every stage adds/adapts branches through transition layers and runs
    a sequence of ``HighResolutionModule`` instances. The outputs of all
    branches of the last stage are resized to the highest resolution,
    concatenated and passed to the prediction heads.
    """

    def __init__(
        self,
        cfg: Dict,
        image_shape: Tuple[int, ...] = (256, 256, 1),
        normalization: str = "none",
        output_channels: List[int] = [1],
        output_channel_info=["F"],
        explicit_activations: bool = False,
        head_activations: List[str] = ["ce_sigmoid"],
        contrast: bool = False,
        contrast_proj_dim: int = 256,
        head_type: str = "FCN",
        activation: str = "relu",
        return_one_tensor: bool = False,
    ):
        """
        Build a 2D/3D High-Resolution Net (HRNet) model.

        Reference: `Deep high-resolution representation learning for visual recognition <https://ieeexplore.ieee.org/abstract/document/9052469/>`_.

        Code adapted from: `Exploring Cross-Image Pixel Contrast for Semantic Segmentation <https://github.com/tfzhou/ContrastiveSeg/tree/main>`_.

        Parameters
        ----------
        cfg : Dict
            HRNet configuration dictionary. Keys read by this constructor:

            * ``NUM_STAGES`` (int, optional): number of stages. Defaults to 3.
            * ``NUM_MODULES`` (List[int]): number of modules of each stage.
            * ``NUM_BRANCHES`` (List[int]): number of branches of each stage.
            * ``NUM_BLOCKS`` (List[List[int]]): blocks per branch, for each stage.
            * ``NUM_CHANNELS`` (List[List[int]]): pre-expansion channel width per
              branch, for each stage.
            * ``BLOCK_TYPE`` (str or List[str]): block type, shared or per stage.
              One of 'BASIC', 'BOTTLENECK', 'CONVNEXT_V1', 'CONVNEXT_V2'.
            * ``YX_DOWN`` (List[int], optional): in-plane downsampling factor per
              stage. Defaults to 2 for every stage.
            * ``Z_DOWN`` (optional): z-axis downsampling per stage for 3D models.
              Booleans mean "downsample by 2" (True) or "keep resolution"
              (False); integers are used as the factor directly. A single value
              applies to all stages. Defaults to True for every stage.

            The entry of stage 0 in ``YX_DOWN``/``Z_DOWN`` is also the stride of
            each of the two stem convolutions.
        image_shape : Tuple[int, ...]
            Dimensions of the input image. E.g., `(y, x, channels)` for 2D or `(z, y, x, channels)` for 3D.
            The last element `image_shape[-1]` is the number of input channels.
        normalization : str, optional
            Type of normalization layer to use throughout the network. Options include
            `'bn'` (Batch Normalization), `'sync_bn'` (Synchronized Batch Normalization for multi-GPU),
            `'in'` (Instance Normalization), `'gn'` (Group Normalization), or `'none'`.
            Defaults to "none".
        output_channels : List[int], optional
            Number of channels of each output head. More than two entries are only
            accepted with ``head_type="FCN"`` and ``contrast=False``.
        output_channel_info : list of str, optional
            One label per head. Possible values are:

            - "X": where X is a letter, e.g. "F" for foreground, "D" for distance, "R" for rays, "C" for contours, etc.
            - "class": classification (e.g. for multi-task learning)
        explicit_activations : bool, optional
            If True, uses explicit activation functions in the last layers.
        head_activations : List[str], optional
            Activation functions to apply to each output head if `explicit_activations` is True.
        contrast : bool, optional
            If True, an additional projection head (`ProjectionHead`) is created to generate
            an embedding suitable for contrastive learning. Defaults to False.
        contrast_proj_dim : int, optional
            The output dimension of the projection embedding when `contrast` is True. Defaults to 256.
        head_type : str, optional
            Type of head to use in the module. Options are: "OCR", "FCN", "ASPP" and "PSP".
        activation : str, optional
            Activation function to use in the HRNet blocks. Default is "relu".
        return_one_tensor : bool, optional
            Whether to return a single tensor with all outputs concatenated (if False, returns a dictionary
            with separate entries). Default is ``False``.

        Raises
        ------
        ValueError
            If ``output_channels`` is empty, if it has more than two entries
            together with ``contrast`` or a non-FCN head, if a class output is
            requested without class activations while ``explicit_activations``
            is True, or if ``head_type`` is not supported.
        """
        super().__init__()

        if len(output_channels) == 0:
            raise ValueError("'output_channels' needs to has at least one value")
        if len(output_channels) > 2:
            if contrast:
                raise ValueError("If 'contrast' is True, 'output_channels' can only have two values at max: one for the main output and one for the class.")
            if head_type != "FCN":
                raise ValueError("If 'head_type' is not 'FCN', 'output_channels' can only have two values at max: one for the main output and one for the class.")

        print("Selected output channels:")
        for i, info in enumerate(output_channel_info):
            print(f" - {i} channel for {info} output")

        self.blocks_dict = {
            "BASIC": HRBasicBlock,
            "BOTTLENECK": HRBottleneck,
            "CONVNEXT_V1": ConvNeXtBlock_V1,
            "CONVNEXT_V2": ConvNeXtBlock_V2,
        }
        self.output_channels = output_channels
        self.output_channel_info = output_channel_info
        self.return_class = "class" in output_channel_info
        self.in_size = 64
        self.ndim = 3 if len(image_shape) == 4 else 2
        self.contrast = contrast
        self.head_type = head_type
        self.explicit_activations = explicit_activations
        self.return_one_tensor = return_one_tensor
        self.activation = activation

        if self.explicit_activations:
            assert len(head_activations) == sum(output_channels), "If 'explicit_activations' is True, 'head_activations' needs to have the same number of values as 'output_channels'"
            self.head_activations, self.class_head_activations = prepare_activation_layers(head_activations, output_channel_info, output_channels)
            if self.return_class and self.class_head_activations is None:
                raise ValueError("If 'return_class' is True, 'head_activations' must be provided.")

        if self.ndim == 3:
            self.conv_call = nn.Conv3d
            self.norm_func = get_norm_3d
            self.dropout = nn.Dropout3d
        else:
            self.conv_call = nn.Conv2d
            self.norm_func = get_norm_2d
            self.dropout = nn.Dropout2d

        # ---------------------------------------------------------
        # Per-stage downsampling configuration
        # ---------------------------------------------------------
        num_stages = cfg.get("NUM_STAGES", 3)
        yx_down = cfg.get("YX_DOWN", [2] * num_stages)
        z_down = cfg.get("Z_DOWN", [True] * num_stages)

        in_channels = image_shape[-1]
        mpool_stem = self._stage_pool_factor(yx_down, z_down, 0, self.ndim)

        # ---------------------------------------------------------
        # Stem: two strided 3x3 convolutions followed by a bottleneck layer
        # ---------------------------------------------------------
        self.conv1_block = ConvBlock(
            conv=self.conv_call,
            in_size=in_channels,
            out_size=64,
            k_size=3,
            padding=1,
            stride=mpool_stem,
            act="none",
            norm=normalization,
            bias=False,
        )
        self.conv2_block = ConvBlock(
            conv=self.conv_call,
            in_size=64,
            out_size=64,
            k_size=3,
            padding=1,
            stride=mpool_stem,
            act=self.activation,
            norm=normalization,
            bias=False,
        )
        self.layer1 = self._make_layer(HRBottleneck, 64, 64, 4, norm=normalization)

        # ---------------------------------------------------------
        # Stages
        # ---------------------------------------------------------
        self.transitions = nn.ModuleList()
        self.stages = nn.ModuleList()

        # layer1 is built from HRBottleneck, so it outputs 64 * expansion channels.
        pre_stage_channels = [64 * HRBottleneck.expansion]

        # The stem applies `mpool_stem` twice, hence the squared stride.
        stem_stride = tuple(s * s for s in mpool_stem)
        branch_strides = [stem_stride]

        block_cfg = cfg["BLOCK_TYPE"]
        for i in range(num_stages):
            mpool_stage = self._stage_pool_factor(yx_down, z_down, i, self.ndim)
            b_type = block_cfg[i] if isinstance(block_cfg, list) else block_cfg
            block = self.blocks_dict[b_type]
            stage_widths = list(cfg["NUM_CHANNELS"][i])
            cur_channels = [ch * block.expansion for ch in stage_widths]

            self.transitions.append(
                self._make_transition_layer(pre_stage_channels, cur_channels, norm=normalization, mpool=mpool_stage)
            )

            # Every branch created by this transition is `mpool_stage` coarser
            # than the branch right above it.
            branch_strides = self._extend_branch_strides(branch_strides, len(cur_channels), mpool_stage)

            stage_cfg = {
                "NUM_MODULES": cfg["NUM_MODULES"][i],
                "NUM_BRANCHES": cfg["NUM_BRANCHES"][i],
                "NUM_BLOCKS": cfg["NUM_BLOCKS"][i],
                # Pre-expansion widths: the module multiplies by the block
                # expansion itself.
                "NUM_CHANNELS": stage_widths,
                "BLOCK": b_type,
            }
            stage, pre_stage_channels = self._make_stage(
                stage_cfg, cur_channels, multi_scale_output=True, norm=normalization, branch_strides=branch_strides
            )
            self.stages.append(stage)

        # The heads see the concatenation of every branch of the last stage.
        head_in_channels = sum(pre_stage_channels)
        self.heads = self._build_heads(head_type, head_in_channels, normalization)

        if self.contrast:
            self.proj_head = ProjectionHead(ndim=self.ndim, in_channels=head_in_channels, proj_dim=contrast_proj_dim)

        # ---------------------------------------------------------
        # Branch 0 keeps the stem resolution for the whole network, so undoing
        # the stem stride brings the logits back to the input resolution.
        # ---------------------------------------------------------
        self.upsample_logits = nn.Upsample(
            scale_factor=stem_stride,
            mode="bilinear" if self.ndim == 2 else "trilinear",
            align_corners=False,
        )

    @staticmethod
    def _stage_pool_factor(yx_down, z_down, idx: int, ndim: int) -> Tuple[int, ...]:
        """
        Return the downsampling factor of stage ``idx`` as a stride tuple.

        Parameters
        ----------
        yx_down : list
            In-plane factor per stage. Anything that is not a list, or an index
            past its end, falls back to 2.
        z_down : list, bool or int
            z-axis setting per stage, or a single value for all stages. An index
            past the end of the list reuses its last entry. Booleans are mapped
            to 2 (True) and 1 (False); other values are used unchanged.
        idx : int
            Stage index.
        ndim : int
            3 returns ``(z, yx, yx)``; any other value returns ``(yx, yx)``.

        Returns
        -------
        Tuple[int, ...]
            Stride tuple for stage ``idx``.
        """
        yx = yx_down[idx] if isinstance(yx_down, list) and idx < len(yx_down) else 2
        if ndim != 3:
            return (yx, yx)

        if isinstance(z_down, (list, tuple)):
            if idx < len(z_down):
                z = z_down[idx]
            else:
                z = z_down[-1] if len(z_down) > 0 else True
        else:
            z = z_down
        # A raw boolean would be read as stride 1 (True) or 0 (False) by the
        # convolution, which is the opposite of / not what the flag means.
        if isinstance(z, bool):
            z = 2 if z else 1
        return (z, yx, yx)

    @staticmethod
    def _extend_branch_strides(strides, num_branches: int, mpool) -> List[Tuple[int, ...]]:
        """
        Return ``strides`` extended to ``num_branches`` entries.

        Every added branch is ``mpool`` times coarser than the branch before it,
        which matches the number of strided convolutions that
        ``_make_transition_layer`` puts in front of each new branch.

        Parameters
        ----------
        strides : List[Tuple[int, ...]]
            Absolute strides of the existing branches (not modified).
        num_branches : int
            Number of branches required.
        mpool : Tuple[int, ...]
            Per-axis downsampling factor of the current stage.

        Returns
        -------
        List[Tuple[int, ...]]
            A new list with at least ``num_branches`` entries.
        """
        extended = list(strides)
        while len(extended) < num_branches:
            extended.append(tuple(a * b for a, b in zip(extended[-1], mpool)))
        return extended

    def _build_heads(self, head_type: str, head_in_channels: int, normalization: str):
        """
        Create the prediction heads.

        The returned ``nn.Sequential`` is used as an ordered container only:
        ``forward`` applies every entry to the same features, it never chains
        them. Entry ``k`` corresponds to ``output_channel_info[k]``.

        Parameters
        ----------
        head_type : str
            One of "ASPP", "PSP", "OCR" or "FCN".
        head_in_channels : int
            Number of channels of the concatenated backbone features.
        normalization : str
            Normalization identifier forwarded to the heads.

        Returns
        -------
        nn.Sequential
            Container with one module per head.

        Raises
        ------
        ValueError
            If ``head_type`` is not supported.
        """
        heads = nn.Sequential()
        if head_type in ("ASPP", "PSP", "OCR"):
            # NOTE(review): ASPP/PSP are created with out_dims=256, not with
            # output_channels[0] -- confirm that these heads end with the
            # expected number of output channels.
            if head_type == "ASPP":
                heads.append(
                    ASPP(
                        conv=self.conv_call,
                        in_dims=head_in_channels,
                        out_dims=256,
                        norm=normalization,
                        rate=[6, 12, 18],
                    )
                )
            elif head_type == "PSP":
                heads.append(
                    PSP(
                        conv=self.conv_call,
                        in_dims=head_in_channels,
                        out_dims=256,
                        norm=normalization,
                        pool_sizes=[1, 2, 3, 6],
                    )
                )
            else:
                heads.append(
                    OCRHead(
                        conv=self.conv_call,
                        in_dims=head_in_channels,
                        out_dims=256,
                        num_classes=self.output_channels[0],
                        norm=normalization,
                        key_dims=256,
                        scale=1.0,
                    )
                )
            # Second head (classification), if requested.
            if len(self.output_channels) > 1:
                heads.append(self.conv_call(head_in_channels, self.output_channels[1], kernel_size=1, padding="same"))
        elif head_type == "FCN":
            if self.contrast:
                # The four layers form ONE head, so they are nested in their own
                # Sequential instead of being registered as four separate heads.
                heads.append(
                    nn.Sequential(
                        self.conv_call(head_in_channels, head_in_channels, kernel_size=3, stride=1, padding=1),
                        self.norm_func(normalization, head_in_channels),
                        self.dropout(0.10),
                        self.conv_call(head_in_channels, self.output_channels[0], kernel_size=1, stride=1, padding=0, bias=False),
                    )
                )
                if len(self.output_channels) > 1:
                    heads.append(self.conv_call(head_in_channels, self.output_channels[1], kernel_size=1, padding="same"))
            else:
                for out_ch in self.output_channels:
                    heads.append(self.conv_call(head_in_channels, out_ch, kernel_size=1, padding="same"))
        else:
            raise ValueError(f"head_type '{head_type}' is not supported. Choose from: 'ASPP', 'PSP', 'OCR', 'FCN'.")
        return heads

    def _make_transition_layer(
        self,
        num_channels_pre_layer: List[int],
        num_channels_cur_layer: List[int],
        norm: str,
        mpool: Tuple[int, ...] = (2, 2),
    ):
        """
        Create the transition layers between two stages.

        For a branch that already exists, a stride-1 3x3 ``ConvBlock`` is added
        when its channel count changes, otherwise the entry is ``None``. Every
        new branch is derived from the last existing branch through one
        ``mpool``-strided 3x3 ``ConvBlock`` per resolution level it lies below
        that branch; only the last of these convolutions changes the channels.

        Parameters
        ----------
        num_channels_pre_layer : List[int]
            Channels of each branch leaving the previous stage.
        num_channels_cur_layer : List[int]
            Channels of each branch entering the current stage.
        norm : str
            Normalization layer to use (one of 'bn', 'sync_bn', 'in', 'gn', or 'none').
        mpool : Tuple[int, ...], optional
            Stride of the downsampling convolutions. Default is (2, 2).

        Returns
        -------
        transition_layers : nn.ModuleList
            One entry (module or ``None``) per branch of the current stage.
        """
        num_pre = len(num_channels_pre_layer)
        transition_layers = []
        for i, cur_ch in enumerate(num_channels_cur_layer):
            if i < num_pre:
                if cur_ch != num_channels_pre_layer[i]:
                    transition_layers.append(
                        ConvBlock(
                            conv=self.conv_call,
                            in_size=num_channels_pre_layer[i],
                            out_size=cur_ch,
                            k_size=3,
                            padding=1,
                            stride=1,
                            act=self.activation,
                            norm=norm,
                            bias=False,
                        )
                    )
                else:
                    transition_layers.append(None)
                continue

            num_convs = i + 1 - num_pre
            in_ch = num_channels_pre_layer[-1]
            convs = []
            for j in range(num_convs):
                out_ch = cur_ch if j == num_convs - 1 else in_ch
                convs.append(
                    ConvBlock(
                        conv=self.conv_call,
                        in_size=in_ch,
                        out_size=out_ch,
                        k_size=3,
                        padding=1,
                        stride=mpool,
                        act=self.activation,
                        norm=norm,
                        bias=False,
                    )
                )
            transition_layers.append(nn.Sequential(*convs))
        return nn.ModuleList(transition_layers)

    def _make_layer(
        self,
        block: "Type[HRBasicBlock | HRBottleneck]",
        in_size: int,
        out_size: int,
        blocks: int,
        stride: int = 1,
        norm: str = "none",
    ):
        """
        Construct a sequential layer made of ``blocks`` residual blocks.

        The first block receives a 1x1 projection on its shortcut when the
        stride is not 1 or when ``in_size`` differs from
        ``out_size * block.expansion``.

        Parameters
        ----------
        block : Type[HRBasicBlock | HRBottleneck]
            Residual block class to instantiate.
        in_size : int
            Number of input channels of the first block.
        out_size : int
            Pre-expansion channel width of the blocks; the layer outputs
            ``out_size * block.expansion`` channels.
        blocks : int
            Number of blocks to create.
        stride : int, optional
            Stride of the first block. Default is 1.
        norm : str, optional
            Normalization layer to use (one of 'bn', 'sync_bn', 'in',
            'gn', or 'none'). Default is 'none'.

        Returns
        -------
        layer : nn.Sequential
            Sequential container with the blocks of the layer.
        """
        expanded = out_size * block.expansion
        downsample = None
        if stride != 1 or in_size != expanded:
            downsample = nn.Sequential(
                self.conv_call(in_size, expanded, kernel_size=1, stride=stride, bias=False),
                self.norm_func(norm, expanded),
            )

        layers = [block(self.conv_call, in_size, out_size, stride, downsample=downsample, norm=norm)]
        layers.extend(block(self.conv_call, expanded, out_size, norm=norm) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _make_stage(
        self,
        layer_config,
        num_inchannels,
        multi_scale_output=True,
        norm="none",
        branch_strides=None,
    ):
        """
        Construct one stage: a sequence of ``HighResolutionModule`` instances.

        Parameters
        ----------
        layer_config : Dict
            Configuration dictionary for the stage. Expected keys are:

            * ``NUM_MODULES``, int: number of modules to create
            * ``NUM_BRANCHES``, int: number of branches in the stage
            * ``NUM_BLOCKS``, List[int]: Number of blocks per branch
            * ``NUM_CHANNELS``, List[int]: Number of channels per branch
            * ``BLOCK``, str: block type, a key of ``self.blocks_dict``
        num_inchannels : List[int]
            Number of input channels for each branch in the stage.
        multi_scale_output : bool, optional
            When False, the last module of the stage only outputs branch 0.
            Default is True.
        norm : str, optional
            Normalization layer to use (one of 'bn', 'sync_bn', 'in',
            'gn', or 'none'). Default is 'none'.
        branch_strides : List[Tuple[int, ...]], optional
            Absolute stride of each branch in the stage. Default is None.

        Returns
        -------
        modules : nn.Sequential
            Sequential container with the modules of the stage.
        num_inchannels : List[int]
            Channels of each branch leaving the stage.
        """
        num_modules = layer_config["NUM_MODULES"]
        block = self.blocks_dict[layer_config["BLOCK"]]

        modules = []
        for i in range(num_modules):
            # Only the last module may be restricted to a single-scale output.
            keep_all_scales = multi_scale_output or i != num_modules - 1
            modules.append(
                HighResolutionModule(
                    ndim=self.ndim,
                    num_branches=layer_config["NUM_BRANCHES"],
                    blocks=block,
                    num_blocks=layer_config["NUM_BLOCKS"],
                    num_inchannels=num_inchannels,
                    num_channels=layer_config["NUM_CHANNELS"],
                    multi_scale_output=keep_all_scales,
                    norm=norm,
                    branch_strides=branch_strides,
                    activation=self.activation,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()
        return nn.Sequential(*modules), num_inchannels

    def forward(self, input) -> "Dict | torch.Tensor":
        """
        Perform the forward pass of the HighResolutionNet.

        The input goes through the stem and then through every stage. The
        branch outputs of the last stage are resized to the resolution of
        branch 0 and concatenated. Each head is applied to these features and
        its output is upsampled by the stem stride. The contrastive embedding,
        when enabled, is computed from the same features and is not upsampled.

        Parameters
        ----------
        input : torch.Tensor
            The input image tensor.
            Expected shape for 2D: `(batch_size, channels, height, width)`.
            Expected shape for 3D: `(batch_size, channels, depth, height, width)`.

        Returns
        -------
        Dict or torch.Tensor
            A tensor with the predictions when "pred" is the only output, or
            when `return_one_tensor` is True (in that case the arg-max of the
            class output, if any, is appended as an extra channel). Otherwise a
            dictionary with keys:

            - "pred": tensor with the main predictions (e.g. segmentation map)
            - "class": tensor with the classification output (if `return_class` is True)
            - "embed": tensor with the contrastive learning embedding (if `contrast` is True)

            If the environment variable ``drop_stage4`` is set, the list of
            branch features of the second-to-last stage is returned instead.
        """
        x = self.conv1_block(input)
        x = self.conv2_block(x)
        x = self.layer1(x)

        y_list = [x]
        second_to_last = len(self.stages) - 2
        for i, (transition, stage) in enumerate(zip(self.transitions, self.stages)):
            x_list = []
            for j, layer in enumerate(transition):
                if layer is None:
                    x_list.append(y_list[j])
                elif j < len(y_list):
                    # Existing branch whose channels are adapted.
                    x_list.append(layer(y_list[j]))
                else:
                    # New branch, derived from the lowest-resolution branch.
                    x_list.append(layer(y_list[-1]))
            y_list = stage(x_list)

            if os.environ.get("drop_stage4") and i == second_to_last:
                return y_list

        # Bring every branch to the resolution of branch 0 and concatenate.
        feat1 = y_list[0]
        if feat1.ndim == 4:
            target_size = (feat1.shape[2], feat1.shape[3])
            mode = "bilinear"
        else:
            target_size = (feat1.shape[2], feat1.shape[3], feat1.shape[4])
            mode = "trilinear"
        feats_to_cat = [feat1]
        for low_res in y_list[1:]:
            feats_to_cat.append(F.interpolate(low_res, size=target_size, mode=mode, align_corners=True))
        feats = torch.cat(feats_to_cat, dim=1)

        # Every head reads the same features; its logits are upsampled back by
        # the stem stride so that predictions match the input resolution.
        class_outs, outs = [], []
        for i, head in enumerate(self.heads):
            logits = self.upsample_logits(head(feats))
            if "class" not in self.output_channel_info[i]:
                outs.append(logits)
            else:
                class_outs.append(logits)
        outs = torch.cat(outs, dim=1)

        if self.explicit_activations:
            # A single activation is applied to the whole tensor, otherwise one
            # activation per channel.
            if len(self.head_activations) == 1:
                outs = self.head_activations[0](outs)
            else:
                for i, act in enumerate(self.head_activations):
                    outs[:, i:i+1] = act(outs[:, i:i+1])
            if self.return_class and self.class_head_activations is not None:
                for i, act in enumerate(self.class_head_activations):
                    class_outs[i] = act(class_outs[i])

        out_dict = {
            "pred": outs,
        }
        if self.return_class:
            out_dict["class"] = torch.cat(class_outs, dim=1)

        # Contrastive learning head
        if self.contrast:
            out_dict["embed"] = self.proj_head(feats)

        if len(out_dict) == 1:
            return out_dict["pred"]
        if self.return_one_tensor:
            if "class" in out_dict:
                return torch.cat((out_dict["pred"], torch.argmax(out_dict["class"], dim=1).unsqueeze(1)), dim=1)
            return out_dict["pred"]
        return out_dict