"""
Metrics and loss functions for BiaPy.
This module provides a variety of metrics and loss functions for evaluating and training
deep learning models in BiaPy. It includes implementations for Jaccard index (IoU),
Dice loss, BCE, Cross-Entropy, contrastive losses, instance segmentation losses,
detection metrics, and wrappers for SSIM, MSE, and MAE-based losses. Both PyTorch and
NumPy-based metrics are supported for 2D and 3D biomedical image analysis.
"""
import torch
import numpy as np
import pandas as pd
from scipy.spatial import distance_matrix
from scipy.optimize import linear_sum_assignment
from torchmetrics import JaccardIndex
from torchmetrics.image import StructuralSimilarityIndexMeasure
from pytorch_msssim import SSIM
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, List, Tuple, Dict, Union
[docs]
def jaccard_index_numpy(y_true, y_pred):
    """
    Compute the Jaccard index (Intersection over Union) between ground truth and prediction.

    Every non-zero element is treated as foreground, so binary masks stored as ``0/1``,
    ``0/255`` (any integer dtype), floats or booleans all give the same result.

    Parameters
    ----------
    y_true : N dim Numpy array
        Ground truth masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.
    y_pred : N dim Numpy array
        Predicted masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    Returns
    -------
    jac : float
        Jaccard index value. ``0`` when both masks are empty.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have a different number of dimensions.
    """
    if y_true.ndim != y_pred.ndim:
        raise ValueError("Dimension mismatch: {} and {} provided".format(y_true.shape, y_pred.shape))

    # Boolean masks: arithmetic on the raw values (e.g. ``y_true - 1``) mis-counts pixels whose
    # value is not exactly 0/1 and can wrap around for unsigned dtypes such as 0/255 uint8 masks.
    gt = np.asarray(y_true).astype(bool)
    pred = np.asarray(y_pred).astype(bool)

    TP = np.count_nonzero(pred & gt)
    FP = np.count_nonzero(pred & ~gt)
    FN = np.count_nonzero(~pred & gt)
    if (TP + FP + FN) == 0:
        return 0
    return TP / (TP + FP + FN)
[docs]
def jaccard_index_numpy_without_background(y_true, y_pred):
    """
    Compute Jaccard index excluding the background class (first channel).

    Only channels ``1:`` of the last axis are considered. Every non-zero element is treated as
    foreground, so ``0/1``, ``0/255``, float or boolean masks all give the same result.

    Parameters
    ----------
    y_true : N dim Numpy array
        Ground truth masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.
    y_pred : N dim Numpy array
        Predicted masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    Returns
    -------
    jac : float
        Jaccard index value. ``0`` when the non-background channels are empty in both masks.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have a different number of dimensions.
    """
    if y_true.ndim != y_pred.ndim:
        raise ValueError("Dimension mismatch: {} and {} provided".format(y_true.shape, y_pred.shape))

    # Drop the background channel, then work on boolean masks so that non 0/1 values
    # (e.g. 0/255 uint8) are neither double-counted nor affected by unsigned wrap-around.
    gt = np.asarray(y_true)[..., 1:].astype(bool)
    pred = np.asarray(y_pred)[..., 1:].astype(bool)

    TP = np.count_nonzero(pred & gt)
    FP = np.count_nonzero(pred & ~gt)
    FN = np.count_nonzero(~pred & gt)
    if (TP + FP + FN) == 0:
        return 0
    return TP / (TP + FP + FN)
[docs]
def weight_binary_ratio(target):
    """
    Build a per-pixel weight map that compensates foreground/background imbalance.

    The foreground fraction is clamped to ``[0.05, 0.95]``. The minority class then receives a
    weight of ``majority_fraction / minority_fraction`` and the majority class a weight of ``1``.
    A constant ``target`` yields an all-ones map.

    Parameters
    ----------
    target : torch.Tensor
        Target tensor. Any non-zero value counts as foreground.

    Returns
    -------
    weight : torch.Tensor
        ``float32`` weight map with the same shape as ``target``.
    """
    if torch.max(target) == torch.min(target):
        return torch.ones_like(target, dtype=torch.float32)

    min_ratio = 5e-2
    foreground = (target != 0).double()

    # Fraction of foreground pixels, kept away from 0 and 1 so the factor stays bounded.
    total = torch.prod(torch.tensor(foreground.shape, dtype=torch.double))
    fg_fraction = torch.clamp(foreground.sum() / total, min=min_ratio, max=1 - min_ratio)
    bg_fraction = 1 - fg_fraction

    if fg_fraction > bg_fraction:
        # Foreground dominates (e.g. affinity maps): boost the background pixels.
        minority = 1 - foreground
        factor = fg_fraction / bg_fraction
    else:
        # Foreground is scarce (e.g. contour maps): boost the foreground pixels.
        minority = foreground
        factor = bg_fraction / fg_fraction

    return (factor * minority + (1 - minority)).float()
[docs]
class jaccard_index:
    """
    Jaccard index (IoU) metric for PyTorch tensors.

    Supports binary and multiclass segmentation, with optional thresholding and ignore index.
    When the model returns a list of outputs (deep supervision), the IoU is averaged over them.
    """

    def __init__(
        self,
        num_classes: int,
        device: torch.device,
        t: float = 0.5,
        model_source: str = "biapy",
        ndim: int = 2,
        ignore_index: int = -1,
    ):
        """
        Define Jaccard index.

        Parameters
        ----------
        num_classes : int
            Number of classes.
        device : Torch device
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps",
            "xpu", "xla" or "meta".
        t : float, optional
            Threshold to be applied.
        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        ignore_index : int, optional
            Value to ignore in the loss calculation. If not provided, no value will be ignored.
        """
        self.model_source = model_source
        # Not used by the metric itself; kept for backward compatibility of the attribute.
        self.loss = torch.nn.CrossEntropyLoss()
        self.device = device
        self.num_classes = num_classes
        self.t = t
        self.ndim = ndim
        self.ignore_index = ignore_index if ignore_index != -1 else None
        task = "multiclass" if self.num_classes > 2 else "binary"
        self.jaccard = JaccardIndex(
            task=task, threshold=self.t, num_classes=self.num_classes, ignore_index=self.ignore_index
        ).to(self.device, non_blocking=True)

    def __call__(self, y_pred, y_true):
        """
        Calculate Jaccard index (intersection over union).

        Parameters
        ----------
        y_pred : torch.Tensor, list of torch.Tensor or dict
            Predicted masks. If a dict, the prediction is read from its ``"pred"`` key. A list is
            treated as several outputs (deep supervision) whose IoU values are averaged.
        y_true : torch.Tensor
            Ground truth masks.

        Returns
        -------
        jaccard : torch.Tensor
            Jaccard index value.
        """
        _y_pred = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred
        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]

        # For those cases that are predicting 2 channels (binary case) we adapt the GT to match.
        # It's supposed to have 0 value as background and 1 as foreground. The channel count is
        # read from the first output so that lists of outputs are supported too.
        if self.model_source == "bmz" and self.num_classes <= 2 and _y_pred[0].shape[1] != y_true.shape[1]:
            y_true = torch.cat((1 - y_true, y_true), 1)

        iou = 0
        for pred in _y_pred:
            spatial_size = pred.shape[-self.ndim :]
            _y_true = scale_target(y_true, spatial_size) if spatial_size != y_true.shape[-self.ndim :] else y_true
            if self.num_classes > 2:
                # Multiclass targets are label maps: drop only the singleton channel axis
                # (never the batch or spatial axes, which may also have size 1).
                if pred.shape[1] > 1 and _y_true.ndim == pred.ndim:
                    _y_true = _y_true.squeeze(1)
                # Restore a missing batch axis.
                if pred.ndim - 2 == _y_true.ndim:
                    _y_true = _y_true.unsqueeze(0)
            iou += self.jaccard(pred, _y_true.long() if _y_true.is_floating_point() else _y_true)
        return iou / len(_y_pred)
[docs]
class multiple_metrics:
    """
    Compute multiple metrics for instance segmentation workflows.
    Supports IoU, L1, and other metrics for multi-head or multi-channel outputs.

    ``metric_names[i]`` is paired with ``out_channels[i]`` (after the optional ``"class"``
    entry is appended and ``"We"`` entries are removed) and with ``metric_func[i]``.
    """

    def __init__(
        self,
        num_classes: int,
        metric_names: List[str],
        device: torch.device,
        out_channels: Optional[List[str]]=["F"],
        channel_extra_opts: Optional[Dict]={},
        ignore_index: int = -1,
        model_source: str = "biapy",
        ndim: int = 2,
    ):
        """
        Define instance segmentation workflow metrics.

        Parameters
        ----------
        num_classes : int
            Number of classes. When greater than 2 a ``"class"`` entry is appended to the output channels.
        metric_names : list of str
            Names of the metrics to use. Names containing ``"IoU (classes)"`` build a multiclass
            IoU, other names containing ``"IoU"`` a binary IoU, and names containing ``"L1"`` an L1 loss.
        device : Torch device
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps",
            "xpu", "xla" or "meta".
        out_channels : list of str, optional
            Output channels to be predicted, one per metric name. If ``None``, a ``"."``
            placeholder is used for each metric.
        channel_extra_opts : dict, optional
            Additional options for each output channel (e.g., {"B": {"mask_values": True}}).
            ``"R"`` (``nrays``) and ``"Db"`` (``val_type``, ``bin_size``) options are read in ``__call__``.
        ignore_index : int, optional
            Value to ignore in the loss calculation. If not provided, no value will be ignored.
        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        Raises
        ------
        ValueError
            If a metric name is not recognized.
        """
        # NOTE(review): ``channel_extra_opts={}`` is a shared mutable default; it is stored by
        # reference below (not mutated in this class), so callers must not mutate it either.
        self.num_classes = num_classes
        self.metric_names = metric_names
        self.device = device
        self.out_channels = out_channels.copy() if out_channels is not None else [".",]*len(metric_names)
        if self.num_classes > 2:
            self.out_channels += ["class"]
        self.out_channels = [x for x in self.out_channels if x != "We"] # Ignore weight extra channel
        self.channel_extra_opts = channel_extra_opts
        self.model_source = model_source
        self.ignore_index = ignore_index if ignore_index != -1 else None
        self.ndim = ndim
        # One metric callable per metric name, in the same order as ``metric_names``.
        self.metric_func = []
        for i in range(len(metric_names)):
            # Substring checks: "IoU (classes)" must be tested before the generic "IoU".
            if "IoU (classes)" in metric_names[i]:
                loss_func = JaccardIndex(
                    task="multiclass", threshold=0.5, num_classes=self.num_classes, ignore_index=self.ignore_index
                ).to(self.device, non_blocking=True)
            elif "IoU" in metric_names[i]:
                loss_func = JaccardIndex(
                    task="binary", threshold=0.5, num_classes=2, ignore_index=self.ignore_index
                ).to(self.device, non_blocking=True)
            elif "L1" in metric_names[i]:
                loss_func = torch.nn.L1Loss()
            else:
                raise ValueError(f"Metric {metric_names[i]} not recognized.")
            self.metric_func.append(loss_func)

    def __call__(self, y_pred, y_true):
        """
        Calculate metrics.

        Parameters
        ----------
        y_true : torch.Tensor
            Ground truth masks.
        y_pred : torch.Tensor, list of Tensors or dict
            Prediction. If a dict, the prediction is read from ``"pred"`` and, when present, the
            class prediction from ``"class"`` (argmax over dim 1).

        Returns
        -------
        dict : dict
            Metrics and their values. Values computed for several outputs (list prediction)
            under the same name are averaged.
        """
        if isinstance(y_pred, dict):
            _y_pred = y_pred["pred"]
        else:
            _y_pred = y_pred
        # Check multi-head
        if isinstance(y_pred, dict) and "class" in y_pred:
            _y_pred_class = torch.argmax(y_pred["class"], dim=1)
        else:
            # Just take the last channel as class prediction from the first output, which is assumed to be the main one
            # NOTE(review): this class prediction keeps the first output's resolution while
            # ``_y_true`` below is rescaled per output; confirm shapes match for later outputs.
            if isinstance(_y_pred, list):
                _y_pred_class = _y_pred[0][:, -1]
            else:
                _y_pred_class = _y_pred[:, -1]
        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]
        res_metrics = {}
        # NOTE(review): the loop variable ``pd`` shadows the module-level pandas import.
        for pd in _y_pred:
            _y_true = scale_target(y_true, pd.shape[-self.ndim :]) if pd.shape[-self.ndim :] != y_true.shape[-self.ndim :] else y_true
            db_val_type = ""
            # NOTE(review): ``pred_ch_start`` is the index into ``out_channels`` and is used both
            # to pick the metric and as the tensor channel offset. Channels spanning several tensor
            # channels ("A", "R", discretized "Db") do not shift the offsets of the channels that
            # follow them; confirm this matches the expected channel layout.
            for pred_ch_start, channel in enumerate(self.out_channels):
                gt_ch_start = pred_ch_start
                if channel == "A":
                    # Takes every remaining prediction channel.
                    pred_ch_end = pd.shape[1]
                    gt_ch_end = pred_ch_end
                elif channel == "R":
                    assert self.channel_extra_opts is not None and "R" in self.channel_extra_opts, "Rays channel options must be provided."
                    pred_ch_end = self.channel_extra_opts["R"].get("nrays", 32) + pred_ch_start
                    gt_ch_end = pred_ch_end
                elif channel == "Db":
                    assert self.channel_extra_opts is not None and "Db" in self.channel_extra_opts, "Distance to border channel options must be provided."
                    db_val_type = self.channel_extra_opts.get("Db", {}).get("val_type", "norm")
                    if db_val_type == "discretize":
                        # One channel per bin plus one (e.g. 11 channels for bin_size=0.1).
                        db_dis_bin_size = self.channel_extra_opts.get("Db", {}).get("bin_size", 0.1)
                        db_dis_K = int(round(1.0 / db_dis_bin_size)) # 10
                        db_channels = db_dis_K + 1
                    else:
                        db_channels = 1
                    pred_ch_end = pred_ch_start + db_channels
                    gt_ch_end = pred_ch_end
                else:
                    pred_ch_end = pred_ch_start + 1
                    gt_ch_end = pred_ch_end
                if self.metric_names[pred_ch_start] not in res_metrics:
                    res_metrics[self.metric_names[pred_ch_start]] = []
                # Measure metric
                if self.metric_names[pred_ch_start] == "IoU (classes)":
                    # NOTE(review): exact-name match here vs. substring match in __init__; the GT
                    # class labels are read from channel 1 regardless of ``out_channels``.
                    res_metrics[self.metric_names[pred_ch_start]].append(self.metric_func[pred_ch_start](_y_pred_class, _y_true[:, 1]))
                else:
                    y_pred_slice = pd[:, pred_ch_start:pred_ch_end]
                    y_true_slice = _y_true[:, gt_ch_start:gt_ch_end].float()
                    # Discretized distances: compare the predicted bin index with the GT value.
                    if y_pred_slice.shape[1] != y_true_slice.shape[1] and "Db" == channel and db_val_type == "discretize":
                        y_pred_slice = torch.argmax(y_pred_slice, dim=1).unsqueeze(1).float()
                        y_true_slice = y_true_slice.float()
                    res_metrics[self.metric_names[pred_ch_start]].append(self.metric_func[pred_ch_start](y_pred_slice, y_true_slice))
        # Mean of same metric values
        for key, value in res_metrics.items():
            if len(value) > 1:
                res_metrics[key] = torch.mean(torch.as_tensor(value))
            else:
                res_metrics[key] = torch.as_tensor(value[0])
        return res_metrics
[docs]
def scale_target(targets_: torch.Tensor, scaled_size: Tuple[int, ...]) -> torch.Tensor:
    """
    Scale the target masks to match the size of the predictions.

    Nearest-neighbour interpolation is used so label values are never blended. Integer and
    boolean targets are supported and keep their dtype.

    Parameters
    ----------
    targets_ : torch.Tensor
        Ground truth masks, with batch and channel axes first (e.g. ``(B, C, H, W)`` for 2D).
    scaled_size : tuple
        Spatial size to scale the masks to.

    Returns
    -------
    targets : torch.Tensor
        Scaled ground truth masks (a new tensor; the input is not modified).
    """
    if targets_.is_floating_point():
        return F.interpolate(targets_, size=scaled_size, mode="nearest")

    # Nearest interpolation is not implemented for every integer/bool dtype, so interpolate in
    # floating point and cast back. float64 represents integers exactly up to 2**53; MPS has no
    # float64 support, so float32 (exact up to 2**24) is used there.
    work_dtype = torch.float32 if targets_.device.type == "mps" else torch.float64
    scaled = F.interpolate(targets_.to(work_dtype), size=scaled_size, mode="nearest")
    return scaled.to(targets_.dtype)
[docs]
class loss_encapsulation(nn.Module):
    """Wrap any loss so it also accepts the model output dict, using its ``"pred"`` entry."""

    def __init__(self, loss):
        """
        Store the loss to be wrapped.

        Parameters
        ----------
        loss : nn.Module or callable
            Loss called as ``loss(prediction, targets)``.
        """
        super().__init__()
        self.loss = loss

    def forward(self, inputs, targets):
        """
        Evaluate the wrapped loss.

        Parameters
        ----------
        inputs : torch.Tensor or dict
            Model predictions. A dict is unwrapped through its ``"pred"`` key.
        targets : torch.Tensor
            Ground truth targets.

        Returns
        -------
        loss : torch.Tensor
            Value returned by the wrapped loss.
        """
        prediction = inputs["pred"] if isinstance(inputs, dict) else inputs
        return self.loss(prediction, targets)
[docs]
class CrossEntropyLoss_wrapper:
    """
    Wrapper for PyTorch's CrossEntropyLoss (more than 2 classes) and BCEWithLogitsLoss (binary).

    Supports manual class rebalancing and lists of outputs (deep supervision), whose losses are
    combined with normalised, geometrically decaying weights.
    """

    def __init__(
        self,
        num_classes: int,
        ndim: int = 2,
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize wrapper to Pytorch's CrossEntropyLoss.

        Parameters
        ----------
        num_classes : int
            Number of classes. ``<= 2`` selects BCEWithLogitsLoss, otherwise CrossEntropyLoss.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        class_rebalance: str, optional
            Whether to reweight classes or not. Options are: "none" and "manual".
        class_weights : list of float, optional
            List of weights for each class to be used in "manual" class rebalancing. E.g. ``[0.7, 0.3]`` for 2 classes.
        ignore_index : int, optional
            Value to ignore in the loss calculation. If not provided, no value will be ignored.
            Only the CrossEntropyLoss (more than 2 classes) uses it.
        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps".
        """
        self.ndim = ndim
        self.num_classes = num_classes
        self.class_rebalance = class_rebalance
        # -100 is CrossEntropyLoss's own default, i.e. "ignore nothing" for BiaPy's -1.
        self.ignore_index = -100 if ignore_index == -1 else ignore_index
        self.device = torch.device("cpu") if device is None else device
        # Decay factor for the weights of intermediate outputs.
        self.gamma = 0.5

        if class_rebalance == "manual":
            self.class_weights = torch.tensor(class_weights, device=device, dtype=torch.float32)
        else:
            self.class_weights = None

        if num_classes > 2:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights)
        else:
            self.loss = torch.nn.BCEWithLogitsLoss(weight=self.class_weights)

    def __call__(self, y_pred, y_true):
        """
        Calculate CrossEntropyLoss.

        Parameters
        ----------
        y_pred : torch.Tensor, list of torch.Tensor or dict
            Predicted logits. A dict is unwrapped through its ``"pred"`` key; a list is treated as
            several outputs whose losses are weighted by ``gamma**i`` (normalised to sum 1).
        y_true : torch.Tensor
            Ground truth masks. Rescaled to each output's spatial size when needed.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        outputs = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred
        if isinstance(outputs, list):
            raw = [self.gamma**k for k in range(len(outputs))]
            total = sum(raw)
            weights = [r / total for r in raw]
        else:
            outputs = [outputs]
            weights = [1.0]

        loss = 0
        for out, weight in zip(outputs, weights):
            size = out.shape[-self.ndim :]
            target = y_true if size == y_true.shape[-self.ndim :] else scale_target(y_true, size)
            if self.num_classes > 2:
                # CrossEntropyLoss expects a channel-less integer label map.
                term = self.loss(out, target[:, 0].type(torch.long))
            else:
                term = self.loss(out, target.type(torch.float32))
            loss += term * weight
        return loss
[docs]
class detection_loss:
    """
    Loss for detection models. Applies BCE to the centroid (first) channel and, when
    ``separated_class_channel`` is set, a foreground-masked Cross Entropy to the class head.
    Supports within-channel rebalancing, manual class weights and channel weighting.
    """

    def __init__(
        self,
        ndim: int = 2,
        class_rebalance_within_channels: bool = True,
        separated_class_channel: bool = False,
        channel_weights = (1, 1),
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize detection loss.

        Parameters
        ----------
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        class_rebalance_within_channels : bool, optional
            If ``True``, the centroid BCE is weighted per pixel (see ``weight_binary_ratio``) so the
            minority class (usually the few centroid pixels) gets more importance.
        separated_class_channel : bool, optional
            When a separated class channel is expected in the predictions e.g. points + classification in detection.
        channel_weights : 2 float tuple, optional
            Weights applied to the centroid term and the class term. E.g. ``(1, 0.2)``.
            Only used if ``separated_class_channel`` is ``True``; otherwise ``(1, 0)`` is used.
        class_rebalance: str, optional
            Whether to reweight classes or not. Options are: "none" and ``"manual"``.
        class_weights : list of float, optional
            Per-class weights for ``"manual"`` rebalancing of the class head. E.g. ``[1, 1.7, 0.5]`` for 3 classes.
        ignore_index : int, optional
            Value to ignore in the loss calculation. If not provided, no value will be ignored.
        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps".

        Raises
        ------
        ValueError
            With ``separated_class_channel``: if ``class_rebalance`` is ``"manual"`` without
            ``class_weights``, or if ``channel_weights`` does not have 2 values.
        """
        if not separated_class_channel:
            # Without a class head only the centroid term contributes.
            self.channel_weights = (1, 0)
        else:
            if class_rebalance == "manual" and not class_weights:
                raise ValueError("class_weights must be provided when class_rebalance is 'manual'")
            if len(channel_weights) != 2:
                raise ValueError("channel_weights must be a tuple of 2 float values when separated_class_channel is True")
            self.channel_weights = channel_weights

        self.ndim = ndim
        self.separated_class_channel = separated_class_channel
        self.class_rebalance_within_channels = class_rebalance_within_channels
        self.class_rebalance = class_rebalance
        # -100 is CrossEntropyLoss's own default, i.e. "ignore nothing" for BiaPy's -1.
        self.ignore_index = -100 if ignore_index == -1 else ignore_index
        self.device = torch.device("cpu") if device is None else device
        # Decay factor for the weights of intermediate outputs.
        self.gamma = 0.5
        self.class_weights = (
            torch.tensor(class_weights, device=device, dtype=torch.float32) if class_rebalance == "manual" else None
        )

        self.centroid_loss = torch.nn.BCEWithLogitsLoss()
        if separated_class_channel:
            # Per-pixel values are kept so they can be masked to foreground pixels.
            self.class_channel_loss = torch.nn.CrossEntropyLoss(
                ignore_index=self.ignore_index, weight=self.class_weights, reduction="none"
            )

    def _class_term(self, class_logits, target):
        """Cross Entropy of the class head averaged over foreground (non-zero centroid) pixels."""
        fg_mask = (target[:, 0] != 0).float()
        if isinstance(self.ignore_index, (int, float)):
            # NOTE(review): the ignore value is checked on the centroid channel, not the class channel.
            fg_mask = fg_mask * (target[:, 0] != self.ignore_index).float()
        per_pixel = self.class_channel_loss(class_logits, target[:, -1].type(torch.long))
        return (per_pixel * fg_mask).sum() / fg_mask.sum().clamp_min(1.0)

    def __call__(self, y_pred, y_true):
        """
        Calculate the detection loss.

        Parameters
        ----------
        y_pred : torch.Tensor, list of torch.Tensor or dict
            Predicted logits. With ``separated_class_channel`` it must be a dict with ``"pred"`` and
            ``"class"`` entries. Lists are weighted by ``gamma**i`` (normalised to sum 1).
        y_true : torch.Tensor
            Ground truth. Channel 0 holds the centroids; with ``separated_class_channel`` it must
            have 2 channels, the last one holding the class labels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        if not self.separated_class_channel:
            outputs = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred
        else:
            outputs = y_pred["pred"]
            class_outputs = y_pred["class"]
            assert (
                y_true.shape[1] == 2
            ), f"In separated_class_channel setting the ground truth is expected to have 2 channels. Provided {y_true.shape}"

        if isinstance(outputs, list):
            raw = [self.gamma**k for k in range(len(outputs))]
            total = sum(raw)
            weights = [r / total for r in raw]
        else:
            outputs = [outputs]
            class_outputs = [class_outputs] if self.separated_class_channel else None
            weights = [1.0]

        loss = 0
        for idx, out in enumerate(outputs):
            size = out.shape[-self.ndim :]
            target = y_true if size == y_true.shape[-self.ndim :] else scale_target(y_true, size)

            if self.class_rebalance_within_channels:
                centroid_fn = torch.nn.BCEWithLogitsLoss(weight=weight_binary_ratio(target[:, 0]))
            else:
                centroid_fn = self.centroid_loss
            term = self.channel_weights[0] * centroid_fn(out[:, 0], target[:, 0].type(torch.float32))

            if self.separated_class_channel:
                term = term + self.channel_weights[1] * self._class_term(class_outputs[idx], target)
            loss += term * weights[idx]
        return loss
[docs]
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed from logits.

    A single-channel prediction is treated as binary (sigmoid). A multi-channel prediction is
    treated as mutually exclusive classes (softmax) and compared against a one-hot encoding of
    the label map when the target is not already one-hot.
    """

    def __init__(self, batch_dice: bool = True, smooth: float = 1e-5):
        """
        Initialize the Dice loss.

        Parameters
        ----------
        batch_dice : bool, optional
            If ``True``, intersection and union are summed over the batch axis too (one Dice per
            channel); otherwise one Dice per sample and channel is computed before averaging.
        smooth : float, optional
            Smoothing term added to numerator and denominator to avoid division by zero.
        """
        super().__init__()
        self.batch_dice = batch_dice
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Compute ``1 - mean(Dice)``.

        Parameters
        ----------
        y_pred : torch.Tensor
            Logits of shape ``(B, C, *spatial)`` (2D or 3D).
        y_true : torch.Tensor
            Either the same shape as ``y_pred`` (binary mask / one-hot), or a label map with a
            singleton channel axis ``(B, 1, *spatial)`` or without it ``(B, *spatial)``.

        Returns
        -------
        loss : torch.Tensor
            Scalar Dice loss.
        """
        num_channels = y_pred.shape[1]

        # 1. Logits -> probabilities.
        if num_channels == 1:
            y_pred = torch.sigmoid(y_pred)
        else:
            y_pred = torch.softmax(y_pred, dim=1)

        # Bring the target to the prediction layout.
        if y_pred.shape != y_true.shape:
            if num_channels == 1:
                # Binary mask given without (or with a differently shaped) channel axis.
                y_true = y_true.reshape(y_pred.shape)
            else:
                # Label map -> one-hot, moving the class axis to position 1 for both 2D and 3D.
                labels = y_true.squeeze(1) if y_true.ndim == y_pred.ndim else y_true
                y_true = F.one_hot(labels.to(torch.int64), num_classes=num_channels).movedim(-1, 1)
        y_true = y_true.float()

        # 2. Reduce over the spatial axes, and over the batch axis too when batch_dice is set.
        axes = list(range(2, y_pred.ndim))
        if self.batch_dice:
            axes = [0] + axes

        # 3. Intersection and union.
        intersection = torch.sum(y_pred * y_true, dim=axes)
        union = torch.sum(y_pred, dim=axes) + torch.sum(y_true, dim=axes)

        # 4. Dice per channel (and per sample when batch_dice is False), then the loss.
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - torch.mean(dice)
[docs]
class DiceCELoss(nn.Module):
    """
    Combines Cross Entropy (or Binary Cross Entropy) and Dice Loss. Supports multi-head (separated_class_channel),
    deep supervision lists, class rebalancing, and channel weighting.

    With ``separated_class_channel`` the spatial branch (prediction/GT channel 0) is always a
    binary problem (BCE + binary Dice), and the classification head is trained with Cross Entropy
    against the last ground-truth channel.
    """

    def __init__(
        self,
        num_classes: int,
        ndim: int = 2,
        separated_class_channel: bool = False,
        batch_dice: bool = True,
        smooth: float = 1e-5,
        model_source: str = "biapy",
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        channel_weights: Tuple[float, float] = (1.0, 1.0),
        w_ce: float = 1.0,
        w_dice: float = 1.0,
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize DiceCELoss.

        Parameters
        ----------
        num_classes : int
            Number of classes.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        separated_class_channel : bool, optional
            For separated_class_channel predictions e.g. points + classification in detection.
        batch_dice : bool, optional
            Whether to calculate Dice loss across the batch dimension or not.
        smooth : float, optional
            Smoothing factor to avoid division by zero in Dice loss.
        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".
        class_rebalance: str, optional
            Whether to reweight classes (inside loss function) or not. Options are: "none", "auto" and "manual".
            "auto" weights the binary (B)CE per pixel; it has no effect on multiclass CE.
        class_weights : list of float, optional
            Per-class weights, applied whenever non-empty. E.g. ``[0.7, 0.3]`` for 2 classes. With
            ``separated_class_channel`` they weight the classification head; otherwise the main CE/BCE.
        channel_weights : tuple of float, optional
            With ``separated_class_channel``: weights of the spatial (CE + Dice) term and of the
            classification term respectively. E.g. ``(1, 0.2)``.
        w_ce : float, optional
            Weight for the Cross Entropy component of the loss.
        w_dice : float, optional
            Weight for the Dice component of the loss.
        ignore_index : int, optional
            Value to ignore in the loss calculation. If not provided, no value will be ignored.
        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps", "xpu", "xla" or "meta".
        """
        super(DiceCELoss, self).__init__()
        self.num_classes = num_classes
        self.ndim = ndim
        self.separated_class_channel = separated_class_channel
        self.batch_dice = batch_dice
        self.smooth = smooth
        self.model_source = model_source
        self.class_rebalance = class_rebalance
        self.channel_weights = channel_weights
        self.w_ce = w_ce
        self.w_dice = w_dice
        # -100 is CrossEntropyLoss's own default, i.e. "ignore nothing" for BiaPy's -1.
        self.ignore_index = ignore_index if ignore_index != -1 else -100
        self.device = device if device is not None else torch.device("cpu")
        self.gamma = 0.5  # For intermediate outputs weighting
        self.class_weights_tensor = None
        if class_weights is not None and len(class_weights) > 0:
            self.class_weights_tensor = torch.tensor(class_weights, device=self.device, dtype=torch.float32)

        # Main CE/BCE criterion (used when there is no separated class channel).
        if num_classes <= 2:
            self.loss = torch.nn.BCEWithLogitsLoss(weight=self.class_weights_tensor)
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights_tensor)

        if self.separated_class_channel:
            # The spatial branch is a single binary channel regardless of ``num_classes``; the
            # class weights describe the classes of the classification head (as in detection_loss).
            self.spatial_loss = torch.nn.BCEWithLogitsLoss()
            self.class_channel_loss = torch.nn.CrossEntropyLoss(
                ignore_index=self.ignore_index, weight=self.class_weights_tensor
            )

    def _compute_dice(self, y_pred: torch.Tensor, y_true: torch.Tensor, binary: Optional[bool] = None) -> torch.Tensor:
        """
        Compute ``1 - mean(Dice)`` on spatial outputs.

        ``binary`` selects sigmoid + direct mask comparison (default: ``num_classes <= 2``);
        otherwise softmax + one-hot of the label map (with or without a singleton channel axis).
        """
        if binary is None:
            binary = self.num_classes <= 2
        if binary:
            pred_probs = torch.sigmoid(y_pred)
            dice_target = y_true.reshape(pred_probs.shape).float()
        else:
            pred_probs = torch.softmax(y_pred, dim=1)
            labels = y_true.squeeze(1) if y_true.ndim == y_pred.ndim else y_true
            # Class axis moved to position 1 for both 2D (B, C, H, W) and 3D (B, C, D, H, W).
            dice_target = F.one_hot(labels.long(), num_classes=self.num_classes).movedim(-1, 1).float()

        axes = list(range(2, y_pred.ndim))
        if self.batch_dice:
            axes = [0] + axes
        intersection = torch.sum(pred_probs * dice_target, dim=axes)
        union = torch.sum(pred_probs, dim=axes) + torch.sum(dice_target, dim=axes)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - torch.mean(dice)

    def _separated_loss(self, pred: torch.Tensor, pred_class: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Binary BCE + Dice on channel 0 plus CE of the class head against the last GT channel."""
        if self.class_rebalance == "auto":
            spatial_fn = torch.nn.BCEWithLogitsLoss(weight=weight_binary_ratio(target[:, 0]))
        else:
            spatial_fn = self.spatial_loss
        ce_spatial = spatial_fn(pred[:, 0], target[:, 0].float())
        dice_spatial = self._compute_dice(pred[:, :1], target[:, :1], binary=True)
        spatial_term = (self.w_ce * ce_spatial) + (self.w_dice * dice_spatial)
        class_term = self.class_channel_loss(pred_class, target[:, -1].type(torch.long))
        return (self.channel_weights[0] * spatial_term) + (self.channel_weights[1] * class_term)

    def _standard_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """(B)CE + Dice for a single-head prediction."""
        if self.class_rebalance == "auto" and self.num_classes <= 2:
            loss_fn = torch.nn.BCEWithLogitsLoss(weight=weight_binary_ratio(target))
        else:
            loss_fn = self.loss
        if self.num_classes <= 2:
            ce_l = loss_fn(pred, target.type(torch.float32))
        else:
            ce_l = loss_fn(pred, target[:, 0].type(torch.long))
        dice_l = self._compute_dice(pred, target)
        return (self.w_ce * ce_l) + (self.w_dice * dice_l)

    def forward(self, y_pred: Union[torch.Tensor, dict, List], y_true: torch.Tensor) -> torch.Tensor:
        """
        Calculate Dice + CE Loss mimicking CrossEntropyLoss_wrapper structure.

        Parameters
        ----------
        y_pred : torch.Tensor, dict, or list of Tensors
            Model predictions. Can be a single tensor, a dict with "pred" and optionally
            "class" keys, or a list of tensors for deep supervision (weighted by ``gamma**i``,
            normalised to sum 1). With ``separated_class_channel`` a dict with both keys is required.
        y_true : torch.Tensor
            Ground truth masks. With ``separated_class_channel`` it must have 2 channels.

        Returns
        -------
        loss : torch.Tensor
            Combined Dice + CE loss value.
        """
        # 1. Extract predictions
        _y_pred_class = None
        if self.separated_class_channel:
            _y_pred = y_pred["pred"]
            _y_pred_class = y_pred["class"]
            assert (
                y_true.shape[1] == 2
            ), f"In separated_class_channel setting the ground truth is expected to have 2 channels. Provided {y_true.shape}"
        else:
            _y_pred = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred

        # 2. Normalise to lists (deep supervision) and compute the per-output weights
        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]
            if self.separated_class_channel:
                _y_pred_class = [_y_pred_class]
            inter_output_weights = [1.0]
        else:
            w = [self.gamma**i for i in range(len(_y_pred))]
            s = sum(w)
            inter_output_weights = [x / s for x in w]

        # 3. bmz binary models may output 2 channels (background, foreground) for a 1-channel GT.
        # Not applied in the separated setting, whose 2-channel GT layout is fixed.
        if (
            not self.separated_class_channel
            and self.model_source == "bmz"
            and self.num_classes <= 2
            and _y_pred[0].shape[1] != y_true.shape[1]
        ):
            y_true = torch.cat((1 - y_true, y_true), 1)

        # 4. Loop over outputs and accumulate the weighted loss
        loss = 0
        for j, pred in enumerate(_y_pred):
            spatial_size = pred.shape[-self.ndim :]
            _y_true = scale_target(y_true, spatial_size) if spatial_size != y_true.shape[-self.ndim :] else y_true
            if self.separated_class_channel:
                _loss = self._separated_loss(pred, _y_pred_class[j], _y_true)
            else:
                _loss = self._standard_loss(pred, _y_true)
            loss += _loss * inter_output_weights[j]
        return loss
[docs]
class ContrastCELoss(nn.Module):
    """
    Contrastive Cross Entropy Loss for semantic segmentation tasks. It mixes the main loss function and the contrastive loss.

    Parameters
    ----------
    main_loss : nn.Module
        The main loss function to be used for the segmentation task.
    ndim : int, optional
        Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.
    weight : float, optional
        Weight for the contrastive loss. Default is 1.0.
        This weight is used to balance the contribution of the contrastive loss in the final loss calculation
        and can be adjusted based on the specific requirements of the task.
    ignore_index : int, optional
        Label to ignore in the loss calculation. Default is -1.
    """

    def __init__(
        self,
        main_loss: nn.Module,
        ndim: int = 2,
        weight: float = 1.0,
        ignore_index: int = -1,
    ):
        """
        Initialize the ContrastCELoss module.

        Parameters
        ----------
        main_loss : nn.Module
            The main loss function to be used for the segmentation task.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.
        weight : float, optional
            Weight for the contrastive loss. Default is 1.0.
        ignore_index : int, optional
            Label to ignore in the loss calculation. Default is -1. Forwarded to ``PixelContrastLoss``.
        """
        super(ContrastCELoss, self).__init__()
        self.ndim = ndim
        self.main_loss = main_loss
        self.contrast_criterion = PixelContrastLoss(ignore_index=ignore_index, ndim=ndim)
        self.loss_weight = weight

    def forward(self, preds, target, with_embed=False):
        """
        Forward pass of the Contrastive Cross Entropy Loss.

        Parameters
        ----------
        preds : dict
            Dictionary containing the predictions from the model. It should contain:
            - "pred": Segmentation predictions (required).
            - "embed": Embedding predictions (required).
            - "segment_queue": Segment queues for contrastive learning (optional).
            - "pixel_queue": Pixel queues for contrastive learning (optional).
        target : torch.Tensor
            Ground truth segmentation masks.
        with_embed : bool, optional
            Whether to include the embedding in the loss calculation. Default is False.
            When False the contrastive term is multiplied by 0 (it still takes part in the graph).

        Returns
        -------
        loss : torch.Tensor
            ``main_loss`` plus ``weight`` times the contrastive loss (or times 0 when ``with_embed`` is False).
            The contrastive loss is only computed when both queues are present; otherwise it is 0.
        """
        assert "pred" in preds, "Segmentation prediction is missing in the input dictionary."
        assert "embed" in preds, "Embedding prediction is missing in the input dictionary."
        seg = preds["pred"]
        embedding = preds["embed"]
        segment_queue = preds["segment_queue"] if "segment_queue" in preds else None
        pixel_queue = preds["pixel_queue"] if "pixel_queue" in preds else None
        # Resize the segmentation to the target's spatial size for the main loss only.
        if seg.shape[-self.ndim :] != target.shape[-self.ndim :]:
            mode = "bilinear" if self.ndim == 2 else "trilinear"
            pred = F.interpolate(input=seg, size=target.shape[-self.ndim :], mode=mode, align_corners=True)
        else:
            pred = seg
        loss = self.main_loss(pred, target)
        loss_contrast = 0
        if segment_queue is not None and pixel_queue is not None:
            queue = torch.cat((segment_queue, pixel_queue), dim=1)
            # When the classes are less or equal 2 the background class channel is not added in BiaPy
            # so can't apply directly an argmax/max operation
            # NOTE(review): ``seg.max(dim=1)`` returns argmax indices and drops the channel axis, so
            # ``predict`` is (B, *spatial) while ``offsets`` has size 2 on dim 1. The product only
            # broadcasts for batch sizes 1 or 2, and the following max then reduces over what was
            # the batch axis. This looks unintended; confirm the expected label construction.
            if seg.shape[1] <= 2:
                _, predict = seg.max(dim=1)
                if predict.ndim == 3:
                    offsets = torch.tensor([1, 2], device=seg.device).view(1, 2, 1, 1)
                else:
                    offsets = torch.tensor([1, 2], device=seg.device).view(1, 2, 1, 1, 1)
                predict = predict * offsets
                predict, _ = predict.max(dim=1)
            else:
                predict = torch.argmax(seg, 1)
            # NOTE(review): ``predict`` comes from the un-resized ``seg`` while ``labels`` is the
            # full-resolution ``target``; presumably PixelContrastLoss reconciles sizes — verify.
            loss_contrast += self.contrast_criterion(
                embedding,
                labels=target,
                predict=predict,
                queue=queue,
            )
        else:
            loss_contrast += 0
        if with_embed:
            return loss + self.loss_weight * loss_contrast
        return loss + 0 * loss_contrast # just a trick to avoid errors in distributed training
[docs]
class PixelContrastLoss(nn.Module):
    """
    Pixel Contrastive Loss for semantic segmentation tasks.

    Supports hard anchor sampling and negative sampling for contrastive learning.
    Every buffer the loss allocates is created on the device of the tensors it
    receives, so it runs on CPU, CUDA or MPS alike.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        base_temperature: float = 0.07,
        ignore_index: int = -1,
        max_samples: int = 1024,
        max_views: int = 1,
        ndim: int = 2,
    ):
        """
        Initialize the Pixel Contrastive Loss for semantic segmentation tasks.

        Parameters
        ----------
        temperature : float, optional
            Temperature parameter for the contrastive loss. Default is 0.07.
        base_temperature : float, optional
            Base temperature for the contrastive loss. Default is 0.07.
        ignore_index : int, optional
            Label to ignore in the loss calculation. Default is -1.
        max_samples : int, optional
            Maximum number of samples to consider for the contrastive loss. Default is 1024.
        max_views : int, optional
            Maximum number of views to consider for the contrastive loss. Default is 1.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.
        """
        super(PixelContrastLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.ignore_index = ignore_index
        self.max_samples = max_samples
        self.max_views = max_views
        self.ndim = ndim

    def _hard_anchor_sampling(
        self, X: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Sample hard anchors from the input features and their corresponding labels.

        Parameters
        ----------
        X : torch.Tensor
            Input features of shape (batch_size, num_samples, feature_dim).
            E.g. (2, 32768, 256) for a batch size of 2, 32768 samples and 256 features.
        y_hat : torch.Tensor
            Ground truth labels of shape (batch_size, num_samples).
            E.g. (2, 32768) for a batch size of 2 and 32768 samples.
        y : torch.Tensor
            Predicted labels of shape (batch_size, num_samples).
            E.g. (2, 32768) for a batch size of 2 and 32768 samples.

        Returns
        -------
        X_ : torch.Tensor or None
            Sampled features of shape (total_classes, n_view, feature_dim), with
            ``n_view = min(max_samples // total_classes, max_views)``. E.g. (82, 1, 256)
            for 82 classes (this can vary depending on the classes found in the
            ground truth) and 256 features.
        y_ : torch.Tensor or None
            Sampled labels of shape (total_classes,). E.g. (82,) for 82 classes.

        Both are ``None`` when no class qualifies as an anchor, or when there are
        so many classes that not even one view per class fits in ``max_samples``.
        """
        batch_size, feat_dim = X.shape[0], X.shape[-1]

        classes = []
        total_classes = 0
        for ii in range(batch_size):
            this_y = y_hat[ii]
            this_classes = [x for x in torch.unique(this_y) if x != self.ignore_index]
            # A class is only usable when it has more pixels than requested views
            this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views]
            classes.append(this_classes)
            total_classes += len(this_classes)

        if total_classes == 0:
            return None, None

        n_view = min(self.max_samples // total_classes, self.max_views)
        if n_view == 0:
            # More classes than max_samples: nothing can be sampled per class and
            # the contrastive step would fail on an empty concatenation.
            return None, None

        X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float, device=X.device)
        y_ = torch.zeros(total_classes, dtype=torch.float, device=X.device)

        X_ptr = 0
        for ii in range(batch_size):
            this_y_hat = y_hat[ii]
            this_y = y[ii]
            for cls_id in classes[ii]:
                # "hard": GT is cls_id but the prediction disagrees; "easy": both agree
                hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
                easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()

                num_hard = hard_indices.shape[0]
                num_easy = easy_indices.shape[0]

                # Aim for half hard / half easy views, topping up from the other
                # group when one of them is too small.
                if num_hard >= n_view / 2 and num_easy >= n_view / 2:
                    num_hard_keep = n_view // 2
                    num_easy_keep = n_view - num_hard_keep
                elif num_hard >= n_view / 2:
                    num_easy_keep = num_easy
                    num_hard_keep = n_view - num_easy_keep
                elif num_easy >= n_view / 2:
                    num_hard_keep = num_hard
                    num_easy_keep = n_view - num_hard_keep
                else:
                    raise Exception("this shoud be never touched! {} {} {}".format(num_hard, num_easy, n_view))

                perm = torch.randperm(num_hard)
                hard_indices = hard_indices[perm[:num_hard_keep]]
                perm = torch.randperm(num_easy)
                easy_indices = easy_indices[perm[:num_easy_keep]]
                indices = torch.cat((hard_indices, easy_indices), dim=0)

                X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1)
                y_[X_ptr] = cls_id
                X_ptr += 1

        return X_, y_

    def _sample_negative(self, Q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Flatten the per-class queue into contrast features and labels.

        Parameters
        ----------
        Q : torch.Tensor
            Queue of shape (class_num, cache_size, feat_size).
            E.g. (2, 60, 256) for 2 classes, 60 samples per class and 256 features.

        Returns
        -------
        X_ : torch.Tensor
            Features of shape (class_num * cache_size, feat_size).
        y_ : torch.Tensor
            Labels of shape (class_num * cache_size, 1).

        Class 0 is not copied from the queue. The last ``cache_size`` rows therefore
        stay as zero features with label 0.
        """
        class_num, cache_size, feat_size = Q.shape

        X_ = torch.zeros((class_num * cache_size, feat_size), dtype=torch.float, device=Q.device)
        y_ = torch.zeros((class_num * cache_size, 1), dtype=torch.float, device=Q.device)

        sample_ptr = 0
        for ii in range(1, class_num):  # class 0 is skipped on purpose
            X_[sample_ptr : sample_ptr + cache_size, ...] = Q[ii, :cache_size, :]
            y_[sample_ptr : sample_ptr + cache_size, ...] = ii
            sample_ptr += cache_size

        return X_, y_

    def _contrastive(
        self,
        X_anchor: torch.Tensor,
        y_anchor: torch.Tensor,
        queue: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Contrastive loss calculation.

        Parameters
        ----------
        X_anchor : torch.Tensor
            Anchor features of shape (total_classes, n_view, feature_dim).
        y_anchor : torch.Tensor
            Anchor labels of shape (total_classes,).
        queue : torch.Tensor, optional
            Queue of negative examples of shape (class_num, cache_size, feat_size).
            E.g. (19, 10000, 256) for 19 classes, 10000 samples per class and 256 features.
            If not provided, the contrastive loss is computed among the anchors only.

        Returns
        -------
        loss : torch.Tensor
            Scalar loss. Anchors without any positive pair contribute 0 instead of NaN.
        """
        anchor_num, n_view = X_anchor.shape[0], X_anchor.shape[1]

        y_anchor = y_anchor.contiguous().view(-1, 1)
        anchor_count = n_view
        anchor_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0)

        if queue is not None:
            X_contrast, y_contrast = self._sample_negative(queue)
            y_contrast = y_contrast.contiguous().view(-1, 1)
            contrast_count = 1
            contrast_feature = X_contrast
        else:
            y_contrast = y_anchor
            contrast_count = n_view
            contrast_feature = anchor_feature

        mask = torch.eq(y_anchor, y_contrast.T).float()

        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
        # Subtract the row max for numerical stability of exp()
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask = mask.repeat(anchor_count, contrast_count)
        neg_mask = 1 - mask

        # Remove each anchor's pairing with itself (column i of row i)
        logits_mask = torch.ones_like(mask).scatter_(
            1, torch.arange(anchor_num * anchor_count, device=mask.device).view(-1, 1), 0
        )
        mask = mask * logits_mask

        neg_logits = (torch.exp(logits) * neg_mask).sum(1, keepdim=True)
        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits + neg_logits)

        # An anchor with no positive pair would give 0/0 = NaN and poison the whole
        # loss. Its numerator is 0, so dividing by 1 makes it contribute 0.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6, torch.ones_like(pos_per_anchor), pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.mean()

    def forward(
        self,
        feats: torch.Tensor,
        labels: torch.Tensor,
        predict: torch.Tensor,
        queue: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass of the Pixel Contrastive Loss.

        Parameters
        ----------
        feats : torch.Tensor
            Input features of shape (batch_size, feat_size, H, W) or (batch_size, feat_size, D, H, W).
            E.g. (2, 256, 128, 256) for a batch size of 2, 256 features and a spatial size of 128x256.
        labels : torch.Tensor
            Ground truth labels of shape (batch_size, C, H, W) or (batch_size, C, D, H, W).
            E.g. (2, 1, 128, 256) for a batch size of 2, 1 channel and a spatial size of 128x256.
        predict : torch.Tensor
            Predicted labels of shape (batch_size, H, W) or (batch_size, D, H, W).
            It is flattened and compared element-wise with the resized labels, so
            it must have the same spatial size as ``feats``.
        queue : torch.Tensor, optional
            Queue of negative examples of shape (class_num, cache_size, feat_size).
            If not provided, the contrastive loss is computed among the anchors only.

        Returns
        -------
        loss : torch.Tensor or int
            Contrastive loss value, or ``0`` when no anchor could be sampled.
        """
        labels = torch.nn.functional.interpolate(labels.float().clone(), feats.shape[-self.ndim :], mode="nearest")
        if labels.shape[1] != 1:
            # Instance segmentation: several GT channels. Channel k is encoded as
            # label k+1 and the highest non-zero one wins per pixel.
            offsets = torch.arange(1, labels.shape[1] + 1, device=labels.device)
            offsets = offsets.view(1, -1, *([1] * (labels.ndim - 2)))
            labels = labels * offsets
            labels, _ = labels.max(dim=1)
        else:
            # Semantic segmentation: the target is already a single channel
            labels = labels.squeeze(1)
        labels = labels.long()
        assert labels.shape[-1] == feats.shape[-1], "Labels ({}) and features ({}) does not match in shape".format(
            labels.shape, feats.shape
        )

        batch_size = feats.shape[0]
        labels = labels.contiguous().view(batch_size, -1)
        predict = predict.contiguous().view(batch_size, -1)

        # Channels-last, then flatten all spatial positions: (B, N, feat_size)
        feats = feats.permute(0, *range(2, feats.ndim), 1)
        feats = feats.contiguous().view(feats.shape[0], -1, feats.shape[-1])

        feats_, labels_ = self._hard_anchor_sampling(feats, labels, predict)
        if feats_ is None or labels_ is None:
            return 0
        return self._contrastive(feats_, labels_, queue=queue)
[docs]
class instance_segmentation_loss:
    """
    Custom loss for instance segmentation tasks in BiaPy.

    This loss combines different loss functions (e.g., BCE, L1, CrossEntropy) for multiple output channels,
    such as binary masks, contours, distances, and class channels. It supports class rebalancing, masking
    of distance channels, and different instance segmentation output types (e.g., "regular", "synapses").
    The loss is configurable for various output channel combinations and can handle multi-class and
    multi-head settings (a list of predictions is weighted with a geometric ``gamma**i`` schedule).

    Parameters
    ----------
    channel_weights : tuple of float, optional
        Weights to be applied to each output channel loss. E.g. (1, 0.2).
    ndim : int, optional
        Number of spatial dimensions (2 or 3).
    out_channels : List of str, optional
        Output channels (e.g., ["F", "C"], ["B", "C", "P"], ["B","C","D"], etc.). A ``"We"`` entry
        means the last GT channel holds per-pixel border weights.
    losses_to_use : list of str, optional
        Loss function to use for each output channel (e.g., ["bce", "mse"]).
    channel_extra_opts : dict, optional
        Additional options for each output channel (e.g., {"D": {"mask_values": True}}).
    gt_channels_expected : int, optional
        Number of channels expected in the ground truth (default: 1).
    class_rebalance : str, optional
        Whether to reweight classes (inside loss function) or not. Options are: "none" and "manual".
    class_weights : List[float], optional
        Weights for each class to be used in the loss calculation (default: None).
    ignore_index : int, optional
        Value to ignore in the loss calculation (default: -1, meaning nothing is ignored).
    device : torch.device, optional
        Device used for internal tensors.
    """

    def __init__(
        self,
        channel_weights=(1, 1),
        ndim: int = 2,
        class_rebalance_within_channels: bool = False,
        separated_class_channel: bool = False,
        out_channels: Optional[List[str]] = None,
        losses_to_use: Optional[List[str]] = None,
        channel_extra_opts: Optional[Dict] = None,
        gt_channels_expected: int = 1,
        class_rebalance: str = "none",
        class_weights: Optional[List[float]] = None,
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize the custom loss that mixes BCE/CE/L1/MSE depending on the ``out_channels`` variable.

        Parameters
        ----------
        channel_weights : tuple of float, optional
            Weights applied to each channel of the data. E.g. if working with F + C channels, ``(1, 2)``
            gives the contours more importance. When ``separated_class_channel`` is ``True`` the last
            weight is used for the class channel.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        class_rebalance_within_channels : bool, optional
            Whether to reweight BCE/CE channels so underrepresented pixels within a channel get more
            importance (e.g. the few center pixels in detection).
        separated_class_channel : bool, optional
            When a separated class channel is expected in the predictions, e.g. instances +
            classification in instance segmentation.
        out_channels : List of str, optional
            Channels to operate with. Defaults to ``["F", "C"]``.
        losses_to_use : list of str, optional
            Loss function for each output channel (``"bce"``, ``"ce"``, ``"l1"``/``"mae"``, ``"mse"``).
        channel_extra_opts : dict, optional
            Additional options for each output channel (e.g., {"B": {"mask_values": True}}).
        gt_channels_expected : int, optional
            Number of channels expected in the ground truth (default: 1). If ``"We"`` is in
            ``out_channels``, one extra channel is expected for the border weights.
        class_rebalance : str, optional
            Whether to reweight classes or not. This only works if ``separated_class_channel`` is
            ``True``. Options are: ``"none"`` and ``"manual"``.
        class_weights : list of float, optional
            Weights for each class used in ``"manual"`` class rebalancing. E.g. ``[1, 1.7, 0.5]``.
        ignore_index : int, optional
            Value to ignore in the loss calculation. ``-1`` (default) disables ignoring.
        device : torch device, optional
            Most commonly "cpu" or "cuda", but also potentially "mps".

        Raises
        ------
        ValueError
            If ``class_rebalance == "manual"`` with a separated class channel but no class weights.
        """
        # None sentinels: avoid sharing mutable default objects between instances
        if out_channels is None:
            out_channels = ["F", "C"]
        if losses_to_use is None:
            losses_to_use = []
        if channel_extra_opts is None:
            channel_extra_opts = {}
        if class_weights is None:
            class_weights = []

        if separated_class_channel and class_rebalance == "manual" and not class_weights:
            raise ValueError("class_weights must be provided when class_rebalance is 'manual'")

        self.channel_weights = channel_weights
        self.class_rebalance_within_channels = class_rebalance_within_channels
        self.separated_class_channel = separated_class_channel
        self.ndim = ndim
        self.out_channels = [x for x in out_channels if x != "We"]
        self.extra_weight_in_borders = out_channels.count("We") > 0
        self.gt_channels_expected = gt_channels_expected if not self.extra_weight_in_borders else gt_channels_expected + 1
        self.channel_extra_opts = channel_extra_opts
        self.class_rebalance = class_rebalance
        self.class_weights = None
        self.ignore_index = ignore_index
        self.ignore_values = ignore_index != -1
        self.losses_to_use = losses_to_use
        self.device = device if device is not None else torch.device("cpu")
        self.gamma = 0.5

        if self.class_rebalance == "manual":
            self.class_weights = torch.tensor(class_weights, device=device, dtype=torch.float32)

        if len(self.losses_to_use) != 0:
            assert len(self.out_channels) == len(self.losses_to_use), "Length of out_channels and losses_to_use should be the same. Provided {} and {}".format(
                self.out_channels, self.losses_to_use
            )

        if self.separated_class_channel:
            self.class_channel_loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights, reduction="none")

    def _foreground_mask(self, y_true: torch.Tensor) -> "Optional[torch.Tensor]":
        """Return a (B, 1, ...) float foreground mask from binary GT channels.

        Checks for F or M first (foreground=1), then B (background=1 → foreground=0).
        Returns None when no binary channel is present in the GT.
        """
        for j, ch in enumerate(self.out_channels):
            if ch in ("F", "M"):
                return (y_true[:, j : j + 1] > 0).float()
            if ch == "B":
                return (y_true[:, j : j + 1] == 0).float()
        return None

    @staticmethod
    def _align_to(t: torch.Tensor, loss_tensor: torch.Tensor) -> torch.Tensor:
        """Drop a singleton channel dim of ``t`` when the loss has none (CE returns (B, ...))."""
        # Without this, (B, 1, ...) * (B, ...) broadcasts to (B, B, ...) and mixes samples.
        if t.ndim == loss_tensor.ndim + 1 and t.shape[1] == 1:
            return t.squeeze(1)
        return t

    @staticmethod
    def _masked_mean(loss_tensor: torch.Tensor, mask: "Optional[torch.Tensor]") -> torch.Tensor:
        """Average ``loss_tensor`` over all elements, or only over the elements selected by ``mask``."""
        if mask is None:
            denom = torch.tensor(loss_tensor.numel(), device=loss_tensor.device, dtype=loss_tensor.dtype)
            return loss_tensor.sum() / denom
        masked = loss_tensor * mask
        # Count masked elements across every broadcast channel, consistent with the
        # unmasked branch which divides by numel (matters for multi-channel heads like R).
        denom = mask.expand_as(masked).sum().clamp_min(1.0)
        return masked.sum() / denom

    def _channel_bounds(self, channel: str, n_pred_channels: int) -> Tuple[int, int, int, int]:
        """Return ``(pred_start, pred_end, gt_start, gt_end)`` channel indices for ``channel``."""
        # NOTE(review): starts are the channel's position in out_channels, so this
        # assumes multi-channel heads (A, R, discretized Db) come last. Confirm.
        pred_ch_start = self.out_channels.index(channel)
        gt_ch_start = pred_ch_start
        if channel == "A":
            pred_ch_end = n_pred_channels
            gt_ch_end = pred_ch_end
        elif channel == "R":
            pred_ch_end = self.channel_extra_opts.get("R", {}).get("nrays", 32) + pred_ch_start
            gt_ch_end = pred_ch_end
        elif channel == "Db":
            db_opts = self.channel_extra_opts.get("Db", {})
            if db_opts.get("val_type", "norm") == "discretize":
                db_dis_K = int(round(1.0 / db_opts.get("bin_size", 0.1)))  # e.g. 10 bins
                db_channels = db_dis_K + 1
            else:
                db_channels = 1
            pred_ch_end = pred_ch_start + db_channels
            gt_ch_end = pred_ch_start + 1  # GT keeps a single Db channel
        else:
            pred_ch_end = pred_ch_start + 1
            gt_ch_end = pred_ch_end
        return pred_ch_start, pred_ch_end, gt_ch_start, gt_ch_end

    def _channel_mask(self, channel: str, y_true: torch.Tensor, y_true_slice: torch.Tensor) -> "Optional[torch.Tensor]":
        """Build the element-wise mask restricting the loss of ``channel`` (None = all pixels)."""
        mask_vals = self.channel_extra_opts.get(channel, {}).get("mask_values", False)
        mask = None
        if channel in ("Gv", "Gh", "Gz"):
            # Flow targets are unit vectors: magnitude is always 1 in the foreground
            # and (0, 0[, 0]) in the background. A foreground pixel with purely
            # horizontal flow has Gv=0, so per-component (!=0) masking would wrongly
            # exclude it. Sum of squared components is the correct foreground proxy.
            flow_channels = [j for j, ch in enumerate(self.out_channels) if ch in ("Gv", "Gh", "Gz")]
            mag_sq = sum(y_true[:, j : j + 1].float() ** 2 for j in flow_channels)
            mask = (mag_sq > 0).float()
        elif channel in ("H", "V", "Z"):
            # HoVer-Net channels: signed values in [-1, 1], centroid = 0 = background.
            # Masking by (!=0) would exclude the center row/column of every cell, so use
            # a binary foreground channel (F/M/B) when present; otherwise train on all pixels.
            mask = self._foreground_mask(y_true)
        elif channel in ("Db", "Dc", "Dn", "R"):
            # Distance channels where legitimate foreground pixels can be 0 (Db at boundaries,
            # Dc at the centroid, Dn for isolated cells, R near the border). Prefer a binary
            # foreground channel; fall back to (y > 0), which misses those zero-valued pixels.
            mask = self._foreground_mask(y_true)
            if mask is None and mask_vals:
                mask = (y_true_slice > 0).float()
        elif mask_vals:
            mask = (y_true_slice != 0).float()
        if self.ignore_values and mask is not None:
            mask = mask * (y_true_slice != self.ignore_index).float()
        return mask

    def _channel_loss(self, i: int, channel: str, prediction: torch.Tensor, y_true: torch.Tensor, w_borders) -> torch.Tensor:
        """Compute the (masked, weighted) mean loss for one output channel of one prediction."""
        pred_ch_start, pred_ch_end, gt_ch_start, gt_ch_end = self._channel_bounds(channel, prediction.shape[1])
        y_pred_slice = prediction[:, pred_ch_start:pred_ch_end]
        y_true_slice = y_true[:, gt_ch_start:gt_ch_end].float()

        mask = self._channel_mask(channel, y_true, y_true_slice)

        if y_pred_slice.shape[-self.ndim :] != y_true_slice.shape[-self.ndim :]:
            # NOTE(review): mask and w_borders are computed at the GT resolution and are
            # not rescaled here; they will not match a lower-resolution output.
            y_true_slice = scale_target(y_true_slice, y_pred_slice.shape[-self.ndim :])

        loss_name = self.losses_to_use[i]

        # Per-pixel weights (class rebalance / ignore_index) for the classification losses
        weight = None
        if loss_name in ["bce", "ce"] and channel in ["B", "F", "P", "C", "T", "A", "M", "F_pre", "F_post"]:
            if self.class_rebalance_within_channels:
                weight = weight_binary_ratio(y_true_slice).float()
            if self.ignore_values:
                ignore_mask = (y_true_slice != self.ignore_index).float()
                weight = ignore_mask if weight is None else weight * ignore_mask

        # Criteria use reduction="none" so weights and masks can be applied per element.
        # Per-pixel weights are applied after the criterion: CrossEntropyLoss's own
        # ``weight`` argument is a per-class vector and rejects per-pixel maps.
        if loss_name == "bce":
            crit = torch.nn.BCEWithLogitsLoss(reduction="none")
        elif loss_name == "ce":
            crit = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction="none")
            y_true_slice = y_true_slice.long().squeeze(1)
        elif loss_name in ["l1", "mae"]:
            crit = torch.nn.L1Loss(reduction="none")
        elif loss_name == "mse":
            crit = torch.nn.MSELoss(reduction="none")
        else:
            raise ValueError("Loss function {} not recognized".format(loss_name))

        if loss_name != "ce":
            y_pred_slice = y_pred_slice.float()
            y_true_slice = y_true_slice.float()

        loss_tensor = crit(y_pred_slice, y_true_slice)
        if weight is not None:
            loss_tensor = loss_tensor * self._align_to(weight, loss_tensor)
        if w_borders is not None:
            loss_tensor = loss_tensor * self._align_to(w_borders, loss_tensor)
        if mask is not None:
            mask = self._align_to(mask, loss_tensor)
        return self._masked_mean(loss_tensor, mask)

    def _class_loss(self, class_pred: torch.Tensor, y_true: torch.Tensor, w_borders) -> torch.Tensor:
        """Cross-entropy on the separated class channel, averaged over instance pixels only."""
        # NOTE(review): both the class labels and the border weights are read from the
        # last GT channel; combining "We" with a separated class channel looks ambiguous.
        class_gt = y_true[:, -1]
        loss_tensor = self.class_channel_loss(class_pred, class_gt.type(torch.long))
        if w_borders is not None:
            loss_tensor = loss_tensor * self._align_to(w_borders, loss_tensor)
        # Always restrict to pixels that belong to an instance (non-zero in the GT), otherwise
        # the many background pixels make the loss too small to drive training.
        class_mask = (class_gt != 0).float()
        if self.ignore_values:
            class_mask = class_mask * (class_gt != self.ignore_index).float()
        return self._masked_mean(loss_tensor, class_mask)

    def __call__(self, y_pred, y_true):
        """
        Calculate instance segmentation loss.

        Parameters
        ----------
        y_pred : torch.Tensor, list of Tensors or dict
            Predictions. A dict carries them under ``"pred"`` (and ``"class"`` when
            ``separated_class_channel`` is ``True``). A list means several outputs
            (e.g. deep supervision), weighted by normalized ``gamma**i``.
        y_true : torch.Tensor
            Ground truth masks of shape (B, gt_channels_expected, ...).

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        _y_pred = y_pred["pred"] if isinstance(y_pred, dict) else y_pred
        _y_pred_class = None
        if self.separated_class_channel:
            assert isinstance(y_pred, dict), "When separated_class_channel is True, y_pred should be a dict with 'pred' and 'class' keys. Provided type: {}".format(type(y_pred))
            assert "class" in y_pred, "When separated_class_channel is True, y_pred should be a dict with 'pred' and 'class' keys. Provided keys: {}".format(y_pred.keys())
            _y_pred_class = y_pred["class"]

        assert (y_true.shape[1] == self.gt_channels_expected), (
            "Seems that the GT loaded doesn't have {} channels as expected in {}. GT shape: {}".format(
                self.gt_channels_expected, self.out_channels, y_true.shape
            )
        )

        # Keep the channel dim (B, 1, ...) so border weights broadcast per sample
        w_borders = y_true[:, -1:] if self.extra_weight_in_borders else None

        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]
            _y_pred_class = [_y_pred_class] if self.separated_class_channel else None
            inter_output_weights = [1.0]
        else:
            w = [self.gamma**i for i in range(len(_y_pred))]
            s = sum(w)
            inter_output_weights = [x / s for x in w]

        loss = 0
        for idx, prediction in enumerate(_y_pred):
            inter_output_loss = 0
            for i, channel in enumerate(self.out_channels):
                inter_output_loss += self.channel_weights[i] * self._channel_loss(i, channel, prediction, y_true, w_borders)

            if self.separated_class_channel:
                loss += self.channel_weights[-1] * self._class_loss(_y_pred_class[idx], y_true, w_borders)

            loss += inter_output_weights[idx] * inter_output_loss
        return loss
[docs]
def detection_metrics(
    true_points,
    pred_points,
    true_classes=None,
    pred_classes=None,
    tolerance=10,
    resolution: Optional[List[Union[int, float]]] = None,
    bbox_to_consider=None,
    verbose=False,
) -> Tuple[Dict[str, float], pd.DataFrame, pd.DataFrame]:
    """
    Calculate detection metrics (precision, recall, F1) for point-based object detection.

    GT and predicted points are paired with a Hungarian assignment on the
    (resolution-scaled) Euclidean distance. A pair closer than ``tolerance`` is a
    true positive. Unmatched GT points are false negatives and unmatched
    predictions are false positives.

    Parameters
    ----------
    true_points : List of list
        Coordinates of ground truth points. E.g. ``[[5,3,2], [4,6,7]]``.
    pred_points : List of list
        Coordinates of predicted points. E.g. ``[[5,3,2], [4,6,7]]``.
    true_classes : List of ints, optional
        Class of each ground truth point.
    pred_classes : List of ints, optional
        Class of each predicted point. Required when ``true_classes`` is set.
    tolerance : optional, int
        Maximum distance far away from a GT point to consider a point as a true positive.
    resolution : List of int/float, optional
        Factor multiplied into each axis before measuring distances. Useful for anisotropic
        data to reduce the distance on the axis with less resolution. E.g. ``(1,1,0.5)``.
        Defaults to ``[1, 1, 1]``.
    bbox_to_consider : List of tuple/list, optional
        Points (GT or predicted) outside this bounding box are tagged ``"NC"`` and
        not counted as TP, FP or FN. Order is
        ``[[z_min, z_max], [y_min, y_max], [x_min, x_max]]`` in original (unscaled) coordinates.
        For example, for an image of ``10x100x200``, to skip the first/last slices and a
        ``15`` pixel border in ``x`` and ``y``: ``[[1, 9], [15, 85], [15, 185]]``.
    verbose : bool, optional
        To print extra information.

    Returns
    -------
    metrics : dict
        ``Precision``, ``Recall``, ``F1``, ``TP``, ``FP``, ``FN``. When classes are given, it
        also has ``Precision (class)``, ``Recall (class)``, ``F1 (class)``, ``TP (class)`` and
        ``FN (class)``, computed on distance-TP pairs only.
    df : pandas.DataFrame
        One row per GT point with its associated prediction, distance and tag.
    df_fp : pandas.DataFrame
        One row per unmatched prediction with its coordinates and tag.

    Raises
    ------
    ValueError
        If classes are given inconsistently with the points.
    """
    if resolution is None:
        resolution = [1, 1, 1]
    if bbox_to_consider is None:
        bbox_to_consider = []

    use_bbox = len(bbox_to_consider) > 0
    if use_bbox:
        assert len(bbox_to_consider) == 3, "'bbox_to_consider' need to be of length 3"
        assert all(len(x) == 2 for x in bbox_to_consider), (
            "'bbox_to_consider' needs to be a list of " "two element array/tuple. E.g. [[1,1],[15,100],[10,200]]"
        )

    if true_classes is not None and pred_classes is None:
        raise ValueError("'pred_classes' must be provided when 'true_classes' is set")

    class_metrics = true_classes is not None and pred_classes is not None
    if class_metrics:
        if len(true_classes) != len(true_points):
            raise ValueError("'true_points' and 'true_classes' length must be the same")
        if len(pred_classes) != len(pred_points):
            raise ValueError("'pred_points' and 'pred_classes' length must be the same")

    # Unscaled coordinates, used for the output DataFrames
    _true = np.array(true_points, dtype=np.float32)
    if len(pred_points) > 0:
        _pred = np.array(pred_points, dtype=np.float32)
    else:
        _pred = np.zeros((0, 3), dtype=np.float32)

    TP, FP, FN = 0, 0, 0
    tag = ["FN"] * len(_true)
    fp_preds = list(range(1, len(_pred) + 1))  # 1-based ids of still-unmatched predictions
    dis = [-1] * len(_true)
    pred_id_assoc = [-1] * len(_true)
    TP_not_considered = 0
    FN_not_considered = 0

    if len(_true) > 0:
        # Distances are measured in physical units: scale copies of each axis
        true_scaled = _true.copy()
        pred_scaled = _pred.copy()
        for axis, factor in enumerate(resolution):
            true_scaled[:, axis] *= factor
            pred_scaled[:, axis] *= factor

        if len(pred_scaled) > 0:
            distances = distance_matrix(pred_scaled, true_scaled)
            n_matched = min(len(_true), len(_pred))
            # Primarily maximize the number of within-tolerance pairs, then prefer shorter ones
            costs = -(distances >= tolerance).astype(float) - distances / (2 * n_matched)
            pred_ind, true_ind = linear_sum_assignment(-costs)

            for p_idx, t_idx in zip(pred_ind, true_ind):
                if distances[p_idx, t_idx] >= tolerance:
                    continue
                if not use_bbox or _point_in_bbox(true_points[t_idx], bbox_to_consider):
                    TP += 1
                    tag[t_idx] = "TP"
                else:
                    tag[t_idx] = "NC"
                    TP_not_considered += 1
                fp_preds.remove(p_idx + 1)
                dis[t_idx] = distances[p_idx, t_idx]
                pred_id_assoc[t_idx] = p_idx + 1

        # Unmatched GT points outside the bounding box are not counted as FN either
        if use_bbox:
            for t_idx, t_tag in enumerate(tag):
                if t_tag == "FN" and not _point_in_bbox(true_points[t_idx], bbox_to_consider):
                    tag[t_idx] = "NC"
                    FN_not_considered += 1

        if TP_not_considered > 0:
            print(f"{TP_not_considered} TPs not considered due to filtering")
        if FN_not_considered > 0:
            print(f"{FN_not_considered} FNs not considered due to filtering")
        FN = len(_true) - TP - TP_not_considered - FN_not_considered

    # FP filtering
    FP_not_considered = 0
    fp_tags = ["FP"] * len(fp_preds)
    if use_bbox:
        for i, pred_id in enumerate(fp_preds):
            if not _point_in_bbox(pred_points[pred_id - 1], bbox_to_consider):
                FP_not_considered += 1
                fp_tags[i] = "NC"
        print(f"{FP_not_considered} FPs not considered due to filtering")
    FP = len(fp_preds) - FP_not_considered

    # Create two dataframes: the GT/prediction associations and the FPs
    df, df_fp = None, None
    df_columns = [
        "gt_id",
        "pred_id",
        "distance",
        "tag",
        "axis-0",
        "axis-1",
        "axis-2",
        "gt_class",
        "pred_axis-0",
        "pred_axis-1",
        "pred_axis-2",
        "pred_class",
    ]
    if len(_true) > 0:
        # Capture FP coords
        fp_coords = np.zeros((len(fp_preds), _pred.shape[-1]))
        pred_fp_class = [-1] * len(fp_preds)
        for i, pred_id in enumerate(fp_preds):
            fp_coords[i] = _pred[pred_id - 1]
            if class_metrics:
                pred_fp_class[i] = int(pred_classes[pred_id - 1])

        # Capture associated prediction coords (zeros when the GT point is unmatched)
        pred_coords = np.zeros((len(pred_id_assoc), _pred.shape[-1]), dtype=np.float32)
        pred_class = [-1] * len(pred_id_assoc)
        gt_classes = true_classes if class_metrics else [-1] * len(pred_id_assoc)
        for i, pred_id in enumerate(pred_id_assoc):
            if pred_id != -1:
                pred_coords[i] = _pred[pred_id - 1]
                if class_metrics:
                    pred_class[i] = int(pred_classes[pred_id - 1])

        df = pd.DataFrame(
            zip(
                list(range(1, len(_true) + 1)),
                pred_id_assoc,
                dis,
                tag,
                _true[..., 0],
                _true[..., 1],
                _true[..., 2],
                gt_classes,  # type: ignore
                pred_coords[..., 0],
                pred_coords[..., 1],
                pred_coords[..., 2],
                pred_class,
            ),  # type: ignore
            columns=df_columns,
        )
        df_fp = pd.DataFrame(
            zip(
                fp_preds,
                fp_coords[..., 0],
                fp_coords[..., 1],
                fp_coords[..., 2],
                fp_tags,
                pred_fp_class,
            ),
            columns=["pred_id", "axis-0", "axis-1", "axis-2", "tag", "pred_class"],
        )

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    F1 = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) > 0 else 0

    TP_classes, FN_classes = 0, 0
    precision_classes = recall_classes = F1_classes = 0
    if not class_metrics:
        if df is not None:
            df = df.drop(columns=["gt_class", "pred_class"])
        if df_fp is not None:
            df_fp = df_fp.drop(columns=["pred_class"])
    else:
        if df is not None:
            # Class metrics must be computed only on true detections (TP by distance),
            # not on all Hungarian associations.
            tp_df = df[df["tag"] == "TP"]
            if len(tp_df) > 0:
                TP_classes = int((tp_df["gt_class"] == tp_df["pred_class"]).sum())
                FN_classes = int((tp_df["gt_class"] != tp_df["pred_class"]).sum())

        n_class_eval = TP_classes + FN_classes
        precision_classes = TP_classes / n_class_eval if n_class_eval > 0 else 0
        recall_classes = TP_classes / n_class_eval if n_class_eval > 0 else 0
        if (precision_classes + recall_classes) > 0:
            F1_classes = 2 * ((precision_classes * recall_classes) / (precision_classes + recall_classes))
        else:
            F1_classes = 0

    if verbose:
        if use_bbox:
            gt_not_considered = TP_not_considered + FN_not_considered
            # A TP excluded by the bbox also removes its matched prediction
            pred_not_considered = TP_not_considered + FP_not_considered
            print(
                "Points in ground truth: {} ({} total but {} not considered), Points in prediction: {} "
                "({} total but {} not considered)".format(
                    len(_true) - gt_not_considered,
                    len(true_points),
                    gt_not_considered,
                    len(_pred) - pred_not_considered,
                    len(pred_points),
                    pred_not_considered,
                )
            )
        else:
            print("Points in ground truth: {}, Points in prediction: {}".format(len(_true), len(_pred)))
        print("True positives: {}, False positives: {}, False negatives: {}".format(int(TP), int(FP), int(FN)))
        if class_metrics:
            print("True positives (class): {}, False negatives (class): {}".format(int(TP_classes), int(FN_classes)))

    r_dict = {
        "Precision": precision,
        "Recall": recall,
        "F1": F1,
        "TP": int(TP),
        "FP": int(FP),
        "FN": int(FN),
    }
    if class_metrics:
        r_dict.update(
            {
                "Precision (class)": precision_classes,
                "Recall (class)": recall_classes,
                "F1 (class)": F1_classes,
                "TP (class)": int(TP_classes),
                "FN (class)": int(FN_classes),
            }
        )

    # Empty fallbacks keep the same columns the populated frames would have
    if df is None:
        if class_metrics:
            empty_columns = df_columns
        else:
            empty_columns = [c for c in df_columns if c not in ("gt_class", "pred_class")]
        df = pd.DataFrame(columns=empty_columns)
    if df_fp is None:
        fp_columns = ["pred_id", "axis-0", "axis-1", "axis-2", "tag"]
        if class_metrics:
            fp_columns.append("pred_class")
        df_fp = pd.DataFrame(columns=fp_columns)

    return r_dict, df, df_fp


def _point_in_bbox(point, bbox) -> bool:
    """Return True when ``point`` lies inside ``bbox`` (``[[lo, hi], ...]``, inclusive) on every axis."""
    return all(lo <= point[axis] <= hi for axis, (lo, hi) in enumerate(bbox))
[docs]
class SSIM_loss(torch.nn.Module):
    """Structural-similarity loss, ``1 - SSIM(input, target)``, backed by torchmetrics."""

    def __init__(self, data_range, device):
        """
        Build the SSIM metric and move it to ``device``.

        Parameters
        ----------
        data_range : float
            Dynamic range of the images (e.g., 1.0 or 255).
        device : torch.device
            Device the metric is moved to.
        """
        super().__init__()
        metric = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.ssim = metric.to(device, non_blocking=True)

    def forward(self, input, target):
        """
        Return one minus the SSIM between prediction and target.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images, or a dict holding them under the ``"pred"`` key.
        target : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            ``1 - SSIM``; lower is better.
        """
        prediction = input["pred"] if isinstance(input, dict) else input
        similarity = self.ssim(prediction, target)
        return 1 - similarity
[docs]
class W_MAE_SSIM_loss(torch.nn.Module):
    """
    Weighted sum of Mean Absolute Error and (1 - SSIM).

    Balances a pixel-wise term (MAE) against a structural/perceptual term (SSIM)
    for image regression tasks.
    """

    def __init__(self, data_range, device, w_mae=0.5, w_ssim=0.5):
        """
        Set up both criteria on ``device``.

        Parameters
        ----------
        data_range : float
            Dynamic range of the images (e.g., 1.0 or 255).
        device : torch.device
            Device the criteria are moved to.
        w_mae : float, optional
            Weight of the MAE term (default: 0.5).
        w_ssim : float, optional
            Weight of the ``1 - SSIM`` term (default: 0.5).
        """
        super().__init__()
        self.w_mae = w_mae
        self.w_ssim = w_ssim
        # NOTE: the attribute is called ``mse`` but holds an L1 (MAE) criterion;
        # the name is kept so existing references keep working.
        l1_criterion = torch.nn.L1Loss()
        self.mse = l1_criterion.to(device, non_blocking=True)
        metric = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.ssim = metric.to(device, non_blocking=True)

    def forward(self, input, target):
        """
        Combine the weighted MAE and SSIM terms.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images, or a dict holding them under the ``"pred"`` key.
        target : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            ``w_mae * MAE + w_ssim * (1 - SSIM)``.
        """
        prediction = input["pred"] if isinstance(input, dict) else input
        mae_term = self.mse(prediction, target) * self.w_mae
        ssim_term = (1 - self.ssim(prediction, target)) * self.w_ssim
        return mae_term + ssim_term
[docs]
class W_MSE_SSIM_loss(torch.nn.Module):
    """
    Weighted sum of a Mean Squared Error term and a ``1 - SSIM`` term.

    Intended for image regression, where ``w_mse`` and ``w_ssim`` trade off pixel-wise
    fidelity against structural similarity.
    """

    def __init__(self, data_range, device, w_mse=0.5, w_ssim=0.5):
        """
        Build the MSE criterion and the SSIM metric on ``device``.

        Parameters
        ----------
        data_range : float
            The value range of the input images (e.g., 1.0 or 255).
        device : torch.device
            Device to use for computation.
        w_mse : float, optional
            Weight of the MSE term (default: 0.5).
        w_ssim : float, optional
            Weight of the ``1 - SSIM`` term (default: 0.5).
        """
        super().__init__()
        self.w_mse = w_mse
        self.w_ssim = w_ssim
        mse_criterion = torch.nn.MSELoss()
        self.mse = mse_criterion.to(device, non_blocking=True)
        ssim_metric = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.ssim = ssim_metric.to(device, non_blocking=True)

    def forward(self, input, target):
        """
        Evaluate ``w_mse * MSE + w_ssim * (1 - SSIM)``.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images. When a dict is given, the prediction is read from its
            ``"pred"`` key.
        target : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            Weighted combination of the MSE and ``1 - SSIM`` terms.
        """
        prediction = input["pred"] if isinstance(input, dict) else input
        pixel_term = self.mse(prediction, target) * self.w_mse
        ssim_term = (1 - self.ssim(prediction, target)) * self.w_ssim
        return pixel_term + ssim_term
[docs]
def n2v_loss_mse(y_pred, y_true):
    """
    Noise2Void masked MSE loss for self-supervised denoising.

    Parameters
    ----------
    y_pred : torch.Tensor or dict
        Predicted output with ``C`` channels along dimension 1. If a dict, the prediction
        is read from its ``"pred"`` key.
    y_true : torch.Tensor
        Concatenation along dimension 1 of the targets (first ``C`` channels) and the
        N2V mask (remaining channels). The mask is presumably binary (1 at the
        manipulated pixels); confirm against the data generator.

    Returns
    -------
    loss : torch.Tensor
        Sum of mask-weighted squared errors divided by the mask sum. When the mask is
        empty, a zero still attached to the autograd graph of ``y_pred`` is returned
        instead of NaN.
    """
    if isinstance(y_pred, dict):
        y_pred = y_pred["pred"]
    n_chan = y_pred.shape[1]
    target = y_true[:, :n_chan]
    mask = y_true[:, n_chan:]
    # Mask the residual, not only the prediction. Otherwise any non-zero target value
    # outside the mask adds a constant target**2 term to the reported loss. Gradients
    # are the same either way.
    sq_err = torch.sum(torch.square(target - y_pred) * mask)
    n_masked = torch.sum(mask)
    if n_masked == 0:
        # Avoid 0/0 = NaN, which would poison training.
        return y_pred.sum() * 0.0
    return sq_err / n_masked
[docs]
class SSIM_wrapper:
    """Callable ``1 - SSIM`` loss backed by ``pytorch_msssim.SSIM``."""

    def __init__(self, data_range=1, channel=1):
        """
        Initiate wrapper to SSIM loss function.

        Parameters
        ----------
        data_range : float, optional
            Value range of the images given to SSIM (default: 1).
        channel : int, optional
            Number of image channels SSIM is configured for (default: 1).
        """
        self.loss = SSIM(data_range=data_range, size_average=True, channel=channel)

    def __call__(self, y_pred, y_true):
        """
        Compute the SSIM loss.

        Parameters
        ----------
        y_pred : torch.Tensor or dict
            Predicted images. If a dict, the prediction is read from its ``"pred"`` key.
        y_true : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            ``1 - SSIM(y_pred, y_true)``; lower is better.
        """
        if isinstance(y_pred, dict):
            y_pred = y_pred["pred"]
        return 1 - self.loss(y_pred, y_true)
[docs]
def lovasz_hinge(logits: torch.Tensor,
labels: torch.Tensor,
per_image: bool = True,
ignore_index: int | None = None) -> torch.Tensor:
"""
Single-function binary Lovász hinge loss.
- logits: unnormalized scores, same shape as labels (e.g. (1,H,W) or (1,D,H,W))
- labels: {0,1} or bool tensor, same shape as logits
- per_image: average loss per-item if a batch dim exists
- ignore_index: label value to ignore (optional)
"""
if logits.shape != labels.shape:
raise ValueError(f"Shape mismatch: logits {logits.shape} vs labels {labels.shape}")
# Handle per-image averaging if a batch dim is present
if per_image and logits.dim() >= 2 and logits.size(0) > 1:
losses = []
for li, yi in zip(logits, labels):
# Flatten and optionally filter ignore_index
l_flat = li.reshape(-1)
y_flat = yi.to(dtype=torch.long, device=li.device).reshape(-1)
if ignore_index is not None:
valid = (y_flat != ignore_index)
l_flat = l_flat[valid]
y_flat = y_flat[valid]
if l_flat.numel() == 0:
continue
# Signs in {-1,+1}, hinge errors, sort desc
signs = y_flat.float() * 2 - 1
errors = 1 - l_flat * signs
errs_sorted, perm = torch.sort(errors, descending=True)
y_sorted = y_flat[perm].float()
# Lovász gradient (Jaccard) in-place, no helpers
p = y_sorted.numel()
if p == 0:
continue
gts = y_sorted.sum()
inter = gts - y_sorted.cumsum(0)
union = gts + (1 - y_sorted).cumsum(0)
jacc = 1.0 - inter / torch.clamp_min(union, 1.0)
if p > 1:
jacc[1:p] = jacc[1:p] - jacc[0:p-1]
losses.append(F.relu(errs_sorted) @ jacc)
return (torch.stack(losses).mean() if len(losses) else logits.new_tensor(0.0))
# Single item (or per_image=False): same steps without the loop
l_flat = logits.reshape(-1)
y_flat = labels.to(dtype=torch.long, device=logits.device).reshape(-1)
if ignore_index is not None:
valid = (y_flat != ignore_index)
l_flat = l_flat[valid]
y_flat = y_flat[valid]
if l_flat.numel() == 0:
return logits.new_tensor(0.0)
signs = y_flat.float() * 2 - 1
errors = 1 - l_flat * signs
errs_sorted, perm = torch.sort(errors, descending=True)
y_sorted = y_flat[perm].float()
p = y_sorted.numel()
gts = y_sorted.sum()
inter = gts - y_sorted.cumsum(0)
union = gts + (1 - y_sorted).cumsum(0)
jacc = 1.0 - inter / torch.clamp_min(union, 1.0)
if p > 1:
jacc[1:p] = jacc[1:p] - jacc[0:p-1]
return F.relu(errs_sorted) @ jacc
[docs]
class SpatialEmbLoss(nn.Module):
    """
    Spatial Embedding Loss for 2D and 3D inspired by `EmbedSeg <https://github.com/juglab/EmbedSeg/tree/main>`__.

    Parameters
    ----------
    patch_size : list or tuple of int, optional
        Training patch shape with a trailing entry that is not used (presumably the
        channel count), i.e. ``(y, x, c)`` for 2D or ``(z, y, x, c)`` for 3D. It sets the
        size of the coordinate map buffer, so inputs must not be spatially larger than
        it. Defaults to ``(1024, 1024, 1)`` in 2D and ``(32, 1024, 1024, 1)`` in 3D. The
        buffer stores 3 floats per grid voxel.
    anisotropy : list or tuple of float or int, optional
        Coordinate extent of each axis ``(z, y, x)``. The z entry is only used in 3D.
    ndims : int, optional
        Number of spatial dimensions (2 or 3).
    center_mode : str, optional
        Method to compute object center: "centroid" or "medoid". Any value other than
        "centroid" selects the medoid.
    medoid_max_points : int, optional
        Maximum number of points used when computing the medoid (to avoid O(N^2) cost);
        ``None`` disables sub-sampling.
    channel_weights : list of float, optional
        Weights for the loss components ``[foreground, instance, variance, seed]``.
        Defaults to all ones.

    Raises
    ------
    ValueError
        If ``ndims`` is not 2 or 3, or ``channel_weights`` has fewer than 4 entries.
    """

    def __init__(
        self,
        patch_size: Optional[Union[List[int], Tuple[int, ...]]] = None,
        anisotropy: Union[List[float], Tuple[float, ...]] = (1, 1, 1),
        ndims: int = 2,
        center_mode: str = "centroid",  # "centroid" or "medoid"
        medoid_max_points: Optional[int] = 10000,  # cap to avoid O(N^2) on huge objects
        channel_weights: Optional[List[float]] = None,
    ):
        super().__init__()
        if ndims not in (2, 3):
            raise ValueError(f"ndims must be 2 or 3, got {ndims}")
        if patch_size is None:
            # Trailing-channel layout, matching the [-3]/[-2] indexing below.
            patch_size = (1024, 1024, 1) if ndims == 2 else (32, 1024, 1024, 1)
        if channel_weights is None:
            channel_weights = [1.0, 1.0, 1.0, 1.0]
        if len(channel_weights) < 4:
            raise ValueError(
                "channel_weights must provide 4 values [foreground, instance, variance, seed], "
                f"got {len(channel_weights)}"
            )
        self.ndims = ndims
        self.center_mode = center_mode
        self.medoid_max_points = medoid_max_points

        # Grid sizes (used to build the coordinate map buffer; sliced to input size on
        # forward). The last patch_size entry is skipped.
        grid_z = patch_size[0] if ndims == 3 else 1
        grid_y = patch_size[-3]
        grid_x = patch_size[-2]

        # Pixel sizes (coordinate extents)
        pixel_z = anisotropy[0] if ndims == 3 else 1
        pixel_y = anisotropy[1]
        pixel_x = anisotropy[2]

        self.channel_weights = channel_weights
        self.foreground_weight = self.channel_weights[0]
        self.w_inst = self.channel_weights[1]
        self.w_var = self.channel_weights[2]
        self.w_seed = self.channel_weights[3]

        # Build a max-size 3D coordinate grid buffer; for 2D, z has size 1 and is sliced
        # away. This lets one class handle both 2D (x, y) and 3D (x, y, z).
        xm = (
            torch.linspace(0, pixel_x, grid_x)
            .view(1, 1, 1, -1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        ym = (
            torch.linspace(0, pixel_y, grid_y)
            .view(1, 1, -1, 1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        zm = (
            torch.linspace(0, pixel_z, grid_z)
            .view(1, -1, 1, 1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        # Stack as (3, Z, Y, X) with channel order (x, y, z); 2D slices (2, Y, X).
        xyzm = torch.cat((xm, ym, zm), 0)
        self.register_buffer("xyzm", xyzm)

    def _calculate_binary_iou(self, pred, label):
        """Return the IoU of two binary masks, or 0 when their union is empty."""
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | (pred == 1)).sum()
        if not union:
            return 0
        return intersection.item() / union.item()

    @torch.no_grad()
    def _center_from_mask(
        self,
        coords: torch.Tensor,  # (D, ...)
        in_mask: torch.Tensor,  # (1, ...)
    ) -> torch.Tensor:
        """
        Compute object center from binary mask using centroid or medoid.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinate grid tensor of shape (D, ...), where D is the number of spatial
            dimensions.
        in_mask : torch.Tensor
            Binary mask tensor of shape (1, ...), indicating the object pixels/voxels.

        Returns
        -------
        center : torch.Tensor
            Center coordinates of shape (D, 1, ..., 1). Zeros if the mask is empty.
        """
        D = coords.size(0)
        # Coordinates of all pixels/voxels in the instance: (N, D)
        pts = coords[in_mask.expand_as(coords)].view(D, -1).t().contiguous()
        if pts.numel() == 0:
            return torch.zeros(D, *([1] * (coords.dim() - 1)), device=coords.device, dtype=coords.dtype)
        if self.center_mode == "centroid" or pts.shape[0] == 1:
            c = pts.mean(0)  # (D,)
        else:
            # Medoid: the point minimizing the sum of Euclidean distances to the others.
            # Optionally sub-sample to keep cdist tractable (approximate medoid).
            if self.medoid_max_points is not None and pts.shape[0] > self.medoid_max_points:
                idx = torch.randperm(pts.shape[0], device=pts.device)[: self.medoid_max_points]
                pts = pts[idx]
            sums = torch.cdist(pts, pts, p=2).sum(dim=1)
            c = pts[torch.argmin(sums)]
        return c.view(D, *([1] * (coords.dim() - 1)))

    def forward(
        self,
        prediction: torch.Tensor,  # (B, C, H, W) or (B, C, Z, H, W)
        instances: torch.Tensor,  # (B, 1, H, W) or (B, 1, Z, H, W)
    ) -> Tuple[torch.Tensor, float, str]:
        """
        Compute the spatial embedding loss and a batch IoU estimate.

        Parameters
        ----------
        prediction : torch.Tensor
            Network output ``(B, C, H, W)`` (2D) or ``(B, C, Z, H, W)`` (3D). Channels
            ``[0, ndims)`` are embedding offsets (passed through ``tanh``), the next
            ``ndims`` are sigma channels and the following one is the seed map (passed
            through ``sigmoid``).
        instances : torch.Tensor
            Instance labels with a singleton channel dimension, ``(B, 1, H, W)`` or
            ``(B, 1, Z, H, W)``; label ``0`` is background.

        Returns
        -------
        loss : torch.Tensor
            Weighted sum of instance (Lovász hinge), variance and seed losses, averaged
            over the batch.
        iou : float
            Mean over the batch of the per-image mean IoU (at threshold 0.5) of the
            predicted instance masks; images without instances contribute 0.
        name : str
            Name of the returned metric ("IoU").

        Raises
        ------
        ValueError
            If ``prediction`` is not 4D/5D, does not match ``ndims``, or is spatially
            larger than the coordinate grid built from ``patch_size``.
        """
        if prediction.dim() not in (4, 5):
            raise ValueError(
                f"Unsupported prediction tensor dimensionality {prediction.dim()}. "
                "Expected 4D (B,C,H,W) or 5D (B,C,D,H,W)."
            )
        B = prediction.size(0)
        D = prediction.dim() - 2  # number of spatial dims (2 or 3)
        if D != self.ndims:
            raise ValueError(f"Model ndims={D} does not match loss ndims={self.ndims}")

        spatial = tuple(prediction.shape[2:])
        grid = tuple(self.xyzm.shape[1:]) if D == 3 else tuple(self.xyzm.shape[2:])
        if any(s > g for s, g in zip(spatial, grid)):
            raise ValueError(
                f"Input spatial size {spatial} exceeds the coordinate grid {grid} built from patch_size"
            )

        if D == 2:
            H, W = spatial
            coords = self.xyzm[:2, 0, :H, :W].contiguous()  # (2, H, W)
            total_voxels = H * W
        else:
            Z, H, W = spatial
            coords = self.xyzm[:3, :Z, :H, :W].contiguous()  # (3, Z, H, W)
            total_voxels = Z * H * W

        # Remove the singleton channel dimension of the instance labels
        instances = instances[:, 0]

        # Channel partition
        emb_ch = D
        sig_ch = self.ndims  # equal to D
        seed_ch = 1

        loss = prediction.new_tensor(0.0)
        iou_sum = 0.0  # accumulated across the whole batch (not reset per image)
        for b in range(B):
            # Spatial embedding (tanh) + coordinate grid
            spatial_emb = torch.tanh(prediction[b, :emb_ch]) + coords  # (D, ...)
            sigma = prediction[b, emb_ch:emb_ch + sig_ch]  # (D, ...)
            seed_map = torch.sigmoid(
                prediction[b, emb_ch + sig_ch: emb_ch + sig_ch + seed_ch]
            )  # (1, ...)

            var_loss = prediction.new_tensor(0.0)
            instance_loss = prediction.new_tensor(0.0)
            seed_loss = prediction.new_tensor(0.0)
            img_iou = 0.0
            obj_count = 0

            instance = instances[b].unsqueeze(0)  # (1, ...)
            instance_ids = instance.unique()
            instance_ids = instance_ids[instance_ids != 0]

            # Regress background seeds to zero
            bg_mask = (instances[b] == 0).unsqueeze(0)  # (1, ...)
            if bg_mask.sum() > 0:
                seed_loss = seed_loss + torch.sum(torch.pow(seed_map[bg_mask] - 0, 2))

            for idv in instance_ids:
                in_mask = instance.eq(idv)  # (1, ...)
                center = self._center_from_mask(coords, in_mask)

                # Sigma stats on object pixels/voxels
                sigma_in = sigma[in_mask.expand_as(sigma)].view(sig_ch, -1)  # (D, N)
                s_mean = sigma_in.mean(1).view(sig_ch, 1)  # (D, 1)

                # Variance loss (before exp); the mean is detached
                var_loss = var_loss + torch.mean(torch.pow(sigma_in - s_mean.detach(), 2))

                # Gaussian-like distance field around the center
                s = torch.exp(s_mean.view(sig_ch, *([1] * D)) * 10)  # (D, 1...1)
                dist = torch.exp(
                    -1 * torch.sum(torch.pow(spatial_emb - center, 2) * s, dim=0, keepdim=True)
                )  # (1, ...)

                # Instance (Lovász hinge) loss on the soft mask
                instance_loss = instance_loss + lovasz_hinge(dist * 2 - 1, in_mask)

                # Seed regression loss towards distance field (fg only)
                seed_loss = seed_loss + self.foreground_weight * torch.sum(
                    torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2)
                )

                # Measure IoU at 0.5 threshold
                img_iou += self._calculate_binary_iou(dist > 0.5, in_mask)
                obj_count += 1

            if obj_count > 0:
                instance_loss = instance_loss / obj_count
                var_loss = var_loss / obj_count
                img_iou = img_iou / obj_count

            seed_loss = seed_loss / total_voxels
            loss = loss + (self.w_inst * instance_loss + self.w_var * var_loss + self.w_seed * seed_loss)
            iou_sum += img_iou

        loss = loss / B
        mean_iou = iou_sum / B if B > 0 else 0.0
        # Adding prediction.sum() * 0 keeps every prediction channel in the graph
        return loss + prediction.sum() * 0, float(mean_iou), "IoU"