# Source code for biapy.engine.metrics

"""
Metrics and loss functions for BiaPy.

This module provides a variety of metrics and loss functions for evaluating and training
deep learning models in BiaPy. It includes implementations for Jaccard index (IoU),
Dice loss, BCE, Cross-Entropy, contrastive losses, instance segmentation losses,
detection metrics, and wrappers for SSIM, MSE, and MAE-based losses. Both PyTorch and
NumPy-based metrics are supported for 2D and 3D biomedical image analysis.
"""
import torch
import numpy as np
import pandas as pd
from scipy.spatial import distance_matrix
from scipy.optimize import linear_sum_assignment
from torchmetrics import JaccardIndex
from torchmetrics.image import StructuralSimilarityIndexMeasure
from pytorch_msssim import SSIM
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, List, Tuple, Dict, Union

def jaccard_index_numpy(y_true, y_pred):
    """
    Compute the Jaccard index (Intersection over Union) between ground truth and prediction.

    The true-positive / false-positive / false-negative counts are derived with
    arithmetic that assumes 0/1 masks.

    Parameters
    ----------
    y_true : N dim Numpy array
        Ground truth masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    y_pred : N dim Numpy array
        Predicted masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    Returns
    -------
    jac : float
        ``TP / (TP + FP + FN)``, or ``0`` when that denominator is zero.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have a different number of dimensions.
    """
    if y_pred.ndim != y_true.ndim:
        raise ValueError("Dimension mismatch: {} and {} provided".format(y_true.shape, y_pred.shape))

    # (pred, true) -> TP, (pred, true - 1) -> FP, (pred - 1, true) -> FN
    factor_pairs = ((y_pred, y_true), (y_pred, y_true - 1), (y_pred - 1, y_true))
    true_pos, false_pos, false_neg = (np.count_nonzero(a * b) for a, b in factor_pairs)

    denominator = true_pos + false_pos + false_neg
    return true_pos / denominator if denominator else 0
def jaccard_index_numpy_without_background(y_true, y_pred):
    """
    Compute the Jaccard index ignoring the first (background) channel.

    Only channels ``1:`` of the last axis are taken into account. Counts assume 0/1 masks.

    Parameters
    ----------
    y_true : N dim Numpy array
        Ground truth masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    y_pred : N dim Numpy array
        Predicted masks. E.g. ``(num_of_images, x, y, channels)`` for 2D images or
        ``(volume_number, z, x, y, channels)`` for 3D volumes.

    Returns
    -------
    jac : float
        ``TP / (TP + FP + FN)`` over the non-background channels, or ``0`` when that
        denominator is zero.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have a different number of dimensions.
    """
    if y_pred.ndim != y_true.ndim:
        raise ValueError("Dimension mismatch: {} and {} provided".format(y_true.shape, y_pred.shape))

    # Drop the background channel (index 0 of the last axis).
    fg_true = y_true[..., 1:]
    fg_pred = y_pred[..., 1:]

    factor_pairs = ((fg_pred, fg_true), (fg_pred, fg_true - 1), (fg_pred - 1, fg_true))
    true_pos, false_pos, false_neg = (np.count_nonzero(a * b) for a, b in factor_pairs)

    denominator = true_pos + false_pos + false_neg
    return true_pos / denominator if denominator else 0
def weight_binary_ratio(target):
    """
    Compute a weight map that balances foreground and background pixels.

    The non-zero fraction of ``target`` is clamped to ``[0.05, 0.95]``. Pixels of the
    minority class get weight ``majority_fraction / minority_fraction`` and all other
    pixels get weight ``1``.

    Parameters
    ----------
    target : torch.Tensor
        Target tensor. Non-zero values are treated as foreground.

    Returns
    -------
    weight : torch.Tensor
        ``float32`` weight map with the same shape as ``target``. All ones when
        ``target`` is constant.
    """
    # A constant target has nothing to balance.
    if torch.max(target) == torch.min(target):
        return torch.ones_like(target, dtype=torch.float32)

    min_ratio = 5e-2
    foreground = (target != 0).double()

    fg_ratio = torch.clamp(foreground.sum() / foreground.numel(), min=min_ratio, max=1 - min_ratio)
    bg_ratio = 1 - fg_ratio
    weight_factor = torch.maximum(fg_ratio, bg_ratio) / torch.minimum(fg_ratio, bg_ratio)

    # When foreground dominates (e.g. affinity maps), the high weight goes to the
    # background pixels. Otherwise (e.g. contour maps) it goes to the foreground.
    boosted = 1 - foreground if fg_ratio > bg_ratio else foreground

    return (weight_factor * boosted + (1 - boosted)).float()
class jaccard_index:
    """
    Jaccard index (IoU) metric for PyTorch tensors.

    Supports binary and multiclass segmentation, with optional thresholding and ignore index.
    """

    def __init__(
        self,
        num_classes: int,
        device: torch.device,
        t: float = 0.5,
        model_source: str = "biapy",
        ndim: int = 2,
        ignore_index: int = -1,
    ):
        """
        Define Jaccard index.

        Parameters
        ----------
        num_classes : int
            Number of classes. ``> 2`` selects the multiclass torchmetrics task, otherwise binary.

        device : Torch device
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps",
            "xpu", "xla" or "meta".

        t : float, optional
            Threshold to be applied.

        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".

        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        ignore_index : int, optional
            Value to ignore in the metric calculation. ``-1`` means no value is ignored.
        """
        self.model_source = model_source
        # Not used by ``__call__``; kept for backward compatibility with code reading ``.loss``.
        self.loss = torch.nn.CrossEntropyLoss()
        self.device = device
        self.num_classes = num_classes
        self.t = t
        self.ndim = ndim
        self.ignore_index = ignore_index if ignore_index != -1 else None

        task = "multiclass" if self.num_classes > 2 else "binary"
        self.jaccard = JaccardIndex(
            task=task, threshold=self.t, num_classes=self.num_classes, ignore_index=self.ignore_index
        ).to(self.device, non_blocking=True)

    def __call__(self, y_pred, y_true):
        """
        Calculate Jaccard index (intersection over union).

        Parameters
        ----------
        y_pred : torch.Tensor, dict or list of torch.Tensor
            Predicted masks. A dict is unwrapped through its ``"pred"`` key. A list (deep
            supervision) is averaged over its elements.

        y_true : torch.Tensor
            Ground truth masks. Resized (nearest) to each prediction's spatial size if needed.

        Returns
        -------
        jaccard : torch.Tensor
            Jaccard index value, averaged over all predictions.
        """
        _y_pred = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred

        # For those cases that are predicting 2 channels (binary case) we adapt the GT to match.
        # It's supposed to have 0 value as background and 1 as foreground.
        if (
            self.model_source == "bmz"
            and self.num_classes <= 2
            and isinstance(_y_pred, torch.Tensor)
            and _y_pred.shape[1] != y_true.shape[1]
        ):
            y_true = torch.cat((1 - y_true, y_true), 1)

        preds = _y_pred if isinstance(_y_pred, list) else [_y_pred]

        iou = 0
        for pred in preds:
            spatial = pred.shape[-self.ndim :]
            _y_true = scale_target(y_true, spatial) if spatial != y_true.shape[-self.ndim :] else y_true

            if self.num_classes > 2 and pred.shape[1] > 1:
                # Multiclass logits (B, C, ...) are compared against a label map (B, ...).
                # Only drop the channel axis: a bare ``squeeze()`` would also remove a
                # singleton batch or spatial axis (e.g. batch 1 with depth 1 in 3D).
                if _y_true.ndim == pred.ndim:
                    _y_true = _y_true.squeeze(1)
                # Restore a missing batch axis, if the target came without one.
                if _y_true.ndim == pred.ndim - 2:
                    _y_true = _y_true.unsqueeze(0)

            # torchmetrics expects integer targets.
            if _y_true.is_floating_point():
                _y_true = _y_true.long()
            iou += self.jaccard(pred, _y_true)

        return iou / len(preds)
class multiple_metrics:
    """
    Compute multiple metrics for instance segmentation workflows.

    Supports IoU, L1, and other metrics for multi-head or multi-channel outputs.
    Each entry of ``out_channels`` is paired, by position, with an entry of
    ``metric_names`` (and of the metric functions built from them).
    """

    def __init__(
        self,
        num_classes: int,
        metric_names: List[str],
        device: torch.device,
        out_channels: Optional[List[str]]=["F"],
        channel_extra_opts: Optional[Dict]={},
        ignore_index: int = -1,
        model_source: str = "biapy",
        ndim: int = 2,
    ):
        """
        Define instance segmentation workflow metrics.

        Parameters
        ----------
        num_classes : int
            Number of classes. When ``> 2`` an extra ``"class"`` entry is appended to ``out_channels``.

        metric_names : list of str
            Names of the metrics to use. Names containing ``"IoU (classes)"`` build a multiclass
            JaccardIndex, other names containing ``"IoU"`` a binary JaccardIndex, and names containing
            ``"L1"`` an ``L1Loss``. Any other name raises ``ValueError``.

        device : Torch device
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps",
            "xpu", "xla" or "meta".

        out_channels : list of str, optional
            Output channels to be predicted. E.g. ["F", "C"] for foreground and class channels.
            ``"We"`` entries are dropped. If ``None``, one ``"."`` placeholder per metric is used.

        channel_extra_opts : dict, optional
            Additional options for each output channel (e.g., {"B": {"mask_values": True}}).
            ``"R"`` reads ``nrays``; ``"Db"`` reads ``val_type`` and ``bin_size``.

        ignore_index : int, optional
            Value to ignore in the IoU calculation. ``-1`` means no value is ignored.

        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".

        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        Raises
        ------
        ValueError
            If a metric name is not recognized.
        """
        self.num_classes = num_classes
        self.metric_names = metric_names
        self.device = device
        # ``.copy()`` keeps the caller's list (and the shared default) from being mutated below.
        self.out_channels = out_channels.copy() if out_channels is not None else [".",]*len(metric_names)
        if self.num_classes > 2:
            self.out_channels += ["class"]
        self.out_channels = [x for x in self.out_channels if x != "We"] # Ignore weight extra channel
        self.channel_extra_opts = channel_extra_opts
        self.model_source = model_source
        self.ignore_index = ignore_index if ignore_index != -1 else None
        self.ndim = ndim

        self.metric_func = []
        for i in range(len(metric_names)):
            # "IoU (classes)" must be tested before the generic "IoU" substring.
            if "IoU (classes)" in metric_names[i]:
                loss_func = JaccardIndex(
                    task="multiclass", threshold=0.5, num_classes=self.num_classes, ignore_index=self.ignore_index
                ).to(self.device, non_blocking=True)
            elif "IoU" in metric_names[i]:
                loss_func = JaccardIndex(
                    task="binary", threshold=0.5, num_classes=2, ignore_index=self.ignore_index
                ).to(self.device, non_blocking=True)
            elif "L1" in metric_names[i]:
                loss_func = torch.nn.L1Loss()
            else:
                raise ValueError(f"Metric {metric_names[i]} not recognized.")
            self.metric_func.append(loss_func)

    def __call__(self, y_pred, y_true):
        """
        Calculate metrics.

        Parameters
        ----------
        y_pred : torch.Tensor, dict or list of Tensors
            Prediction. A dict is unwrapped through ``"pred"``. If it also has a ``"class"`` key,
            its argmax over dim 1 is used for the ``"IoU (classes)"`` metric.

        y_true : torch.Tensor
            Ground truth masks.

        Returns
        -------
        dict : dict
            Metric name -> tensor value. Values from several predictions (deep supervision)
            are averaged.
        """
        if isinstance(y_pred, dict):
            _y_pred = y_pred["pred"]
        else:
            _y_pred = y_pred

        # Check multi-head
        if isinstance(y_pred, dict) and "class" in y_pred:
            _y_pred_class = torch.argmax(y_pred["class"], dim=1)
        else:
            # Just take the last channel as class prediction from the first output, which is assumed to be the main one
            if isinstance(_y_pred, list):
                _y_pred_class = _y_pred[0][:, -1]
            else:
                _y_pred_class = _y_pred[:, -1]

        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]

        res_metrics = {}
        # NOTE: ``pd`` shadows the module-level pandas alias inside this method.
        for pd in _y_pred:
            _y_true = scale_target(y_true, pd.shape[-self.ndim :]) if pd.shape[-self.ndim :] != y_true.shape[-self.ndim :] else y_true
            db_val_type = ""
            # NOTE(review): ``pred_ch_start`` is the position of ``channel`` in ``out_channels``. It is
            # used both as the first channel index into ``pd``/``_y_true`` and as the index into
            # ``metric_names``/``metric_func``. For entries spanning several channels ("R", discretized
            # "Db"), a following entry still starts at position + 1 rather than after those channels.
            # Confirm the intended channel layout before relying on it.
            for pred_ch_start, channel in enumerate(self.out_channels):
                gt_ch_start = pred_ch_start
                if channel == "A":
                    # Takes every remaining prediction channel.
                    pred_ch_end = pd.shape[1]
                    gt_ch_end = pred_ch_end
                elif channel == "R":
                    assert self.channel_extra_opts is not None and "R" in self.channel_extra_opts, "Rays channel options must be provided."
                    pred_ch_end = self.channel_extra_opts["R"].get("nrays", 32) + pred_ch_start
                    gt_ch_end = pred_ch_end
                elif channel == "Db":
                    assert self.channel_extra_opts is not None and "Db" in self.channel_extra_opts, "Distance to border channel options must be provided."
                    db_val_type = self.channel_extra_opts.get("Db", {}).get("val_type", "norm")
                    if db_val_type == "discretize":
                        # One prediction channel per bin, plus one.
                        db_dis_bin_size = self.channel_extra_opts.get("Db", {}).get("bin_size", 0.1)
                        db_dis_K = int(round(1.0 / db_dis_bin_size)) # 10 with the default bin_size of 0.1
                        db_channels = db_dis_K + 1
                    else:
                        db_channels = 1
                    pred_ch_end = pred_ch_start + db_channels
                    gt_ch_end = pred_ch_end
                else:
                    pred_ch_end = pred_ch_start + 1
                    gt_ch_end = pred_ch_end

                if self.metric_names[pred_ch_start] not in res_metrics:
                    res_metrics[self.metric_names[pred_ch_start]] = []

                # Measure metric
                if self.metric_names[pred_ch_start] == "IoU (classes)":
                    # NOTE(review): the class ground truth is always read from channel 1 of ``_y_true``,
                    # independent of where "class" sits in ``out_channels`` — confirm against the target layout.
                    res_metrics[self.metric_names[pred_ch_start]].append(self.metric_func[pred_ch_start](_y_pred_class, _y_true[:, 1]))
                else:
                    y_pred_slice = pd[:, pred_ch_start:pred_ch_end]
                    y_true_slice = _y_true[:, gt_ch_start:gt_ch_end].float()
                    # Discretized distances: the prediction has one channel per bin while the GT slice
                    # may have fewer, so compare the predicted bin index instead.
                    if y_pred_slice.shape[1] != y_true_slice.shape[1] and "Db" == channel and db_val_type == "discretize":
                        y_pred_slice = torch.argmax(y_pred_slice, dim=1).unsqueeze(1).float()
                        y_true_slice = y_true_slice.float()
                    res_metrics[self.metric_names[pred_ch_start]].append(self.metric_func[pred_ch_start](y_pred_slice, y_true_slice))

        # Mean of same metric values (several entries come from deep-supervision outputs or repeated names)
        for key, value in res_metrics.items():
            if len(value) > 1:
                res_metrics[key] = torch.mean(torch.as_tensor(value))
            else:
                res_metrics[key] = torch.as_tensor(value[0])
        return res_metrics
def scale_target(targets_: torch.Tensor, scaled_size: Tuple[int, ...]) -> torch.Tensor:
    """
    Scale the target masks to match the size of the predictions.

    Nearest-neighbour resampling is used so label values are never blended.

    Parameters
    ----------
    targets_ : torch.Tensor
        Ground truth masks, laid out as ``(B, C, *spatial)``. The input is not modified.

    scaled_size : tuple
        Spatial size to scale the masks to.

    Returns
    -------
    targets : torch.Tensor
        Scaled ground truth masks, with the same dtype as ``targets_``.
    """
    # ``F.interpolate`` always allocates a new output, so no defensive clone is needed.
    if targets_.is_floating_point():
        return F.interpolate(targets_, size=scaled_size, mode="nearest")

    # Integer/bool label maps: nearest resampling only copies source values, so doing it
    # in float64 (exact for integers below 2**53) and casting back preserves the labels.
    resized = F.interpolate(targets_.double(), size=scaled_size, mode="nearest")
    return resized.to(targets_.dtype)
class loss_encapsulation(nn.Module):
    """Wrap any loss so it also accepts the model's output dict, detaching the prediction stored under ``"pred"``."""

    def __init__(self, loss):
        """
        Initialize the loss_encapsulation module.

        Parameters
        ----------
        loss : nn.Module or callable
            The loss function to wrap. It is called as ``loss(prediction, targets)``.
        """
        super().__init__()
        self.loss = loss

    def forward(self, inputs, targets):
        """
        Forward pass for the encapsulated loss.

        Parameters
        ----------
        inputs : torch.Tensor or dict
            Model predictions. If a dict, the prediction is taken from its ``"pred"`` key.

        targets : torch.Tensor
            Ground truth targets.

        Returns
        -------
        loss : torch.Tensor
            Computed loss value.
        """
        prediction = inputs["pred"] if isinstance(inputs, dict) else inputs
        return self.loss(prediction, targets)
class CrossEntropyLoss_wrapper:
    """
    Wrapper for PyTorch's CrossEntropyLoss and BCEWithLogitsLoss with support for class rebalancing.

    Binary problems (``num_classes <= 2``) use ``BCEWithLogitsLoss``; multiclass problems use
    ``CrossEntropyLoss`` against the first channel of the target.
    """

    def __init__(
        self,
        num_classes: int,
        ndim: int = 2,
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize wrapper to Pytorch's CrossEntropyLoss.

        Parameters
        ----------
        num_classes : int
            Number of classes.

        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        class_rebalance: str, optional
            Whether to reweight classes or not. Options are: "none" and "manual".

        class_weights : list of float, optional
            List of weights for each class to be used in "manual" class rebalancing. E.g. ``[0.7, 0.3]`` for 2 classes.

        ignore_index : int, optional
            Value to ignore in the loss calculation (multiclass only). ``-1`` maps to PyTorch's default ``-100``.

        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps".
        """
        self.ndim = ndim
        self.num_classes = num_classes
        self.class_rebalance = class_rebalance
        # -100 is CrossEntropyLoss' own default ignore index.
        self.ignore_index = -100 if ignore_index == -1 else ignore_index
        self.device = torch.device("cpu") if device is None else device
        # Geometric decay used to weight intermediate (deep-supervision) outputs.
        self.gamma = 0.5

        self.class_weights = (
            torch.tensor(class_weights, device=device, dtype=torch.float32)
            if self.class_rebalance == "manual"
            else None
        )

        if num_classes > 2:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights)
        else:
            self.loss = torch.nn.BCEWithLogitsLoss(weight=self.class_weights)

    def __call__(self, y_pred, y_true):
        """
        Calculate CrossEntropyLoss.

        Parameters
        ----------
        y_pred : torch.Tensor, dict or list of torch.Tensor
            Predicted logits. A dict is unwrapped through ``"pred"``. For a list, output ``k``
            is weighted by ``gamma**k``, normalized so the weights sum to 1.

        y_true : torch.Tensor
            Ground truth masks. Resized (nearest) to each prediction's spatial size if needed.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        preds = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred

        if isinstance(preds, list):
            raw = [self.gamma**k for k in range(len(preds))]
            total = sum(raw)
            out_weights = [r / total for r in raw]
        else:
            preds, out_weights = [preds], [1.0]

        loss = 0
        for pred, out_w in zip(preds, out_weights):
            size = pred.shape[-self.ndim :]
            target = y_true if size == y_true.shape[-self.ndim :] else scale_target(y_true, size)
            if self.num_classes > 2:
                term = self.loss(pred, target[:, 0].type(torch.long))
            else:
                term = self.loss(pred, target.type(torch.float32))
            loss += term * out_w
        return loss
class detection_loss:
    """
    Loss for detection models with separated_class_channel.

    Combines BCE for the foreground (centroid) channel and CrossEntropy for the class channel,
    with optional class rebalancing and channel weighting.
    """

    def __init__(
        self,
        ndim: int = 2,
        class_rebalance_within_channels: bool = True,
        separated_class_channel: bool = False,
        channel_weights = (1, 1),
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize detection loss.

        Parameters
        ----------
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        class_rebalance_within_channels : bool, optional
            Whether to weight the centroid BCE with ``weight_binary_ratio`` of the centroid target,
            so the (usually scarce) centroid pixels get more importance than background pixels.

        separated_class_channel : bool, optional
            When a separated class channel is expected in the predictions e.g. points + classification in detection.

        channel_weights : 2 float tuple, optional
            Weights for the centroid and class terms. E.g. ``(1, 0.2)``. Only used if
            ``separated_class_channel`` is ``True``; otherwise ``(1, 0)`` is used.

        class_rebalance: str, optional
            Whether to reweight classes or not. Options are: "none" and ``"manual"``.

        class_weights : list of float, optional
            Per-class weights for the class-channel CrossEntropy, used when ``class_rebalance`` is
            ``"manual"``. E.g. ``[1, 1.7, 0.5]`` for 3 classes.

        ignore_index : int, optional
            Value to ignore in the loss calculation. ``-1`` maps to PyTorch's default ``-100``.

        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps".

        Raises
        ------
        ValueError
            With ``separated_class_channel``: if ``class_rebalance == "manual"`` without
            ``class_weights``, or if ``channel_weights`` does not have 2 values.
        """
        if not separated_class_channel:
            # Without a class head only the centroid term contributes.
            self.channel_weights = (1, 0)
        else:
            if class_rebalance == "manual" and not class_weights:
                raise ValueError("class_weights must be provided when class_rebalance is 'manual'")
            self.channel_weights = channel_weights
            if len(self.channel_weights) != 2:
                raise ValueError("channel_weights must be a tuple of 2 float values when separated_class_channel is True")

        self.ndim = ndim
        self.separated_class_channel = separated_class_channel
        self.class_rebalance_within_channels = class_rebalance_within_channels
        self.class_rebalance = class_rebalance
        # -100 is CrossEntropyLoss' own default ignore index.
        self.ignore_index = -100 if ignore_index == -1 else ignore_index
        self.device = torch.device("cpu") if device is None else device
        # Geometric decay used to weight intermediate (deep-supervision) outputs.
        self.gamma = 0.5
        self.class_weights = (
            torch.tensor(class_weights, device=device, dtype=torch.float32) if class_rebalance == "manual" else None
        )

        self.centroid_loss = torch.nn.BCEWithLogitsLoss()
        if separated_class_channel:
            # Per-pixel losses are kept so they can be masked to foreground pixels in ``__call__``.
            self.class_channel_loss = torch.nn.CrossEntropyLoss(
                ignore_index=self.ignore_index, weight=self.class_weights, reduction="none"
            )

    def __call__(self, y_pred, y_true):
        """
        Calculate the detection loss (centroid BCE plus, optionally, masked class CrossEntropy).

        Parameters
        ----------
        y_pred : torch.Tensor, dict or list of torch.Tensor
            Predicted logits. With ``separated_class_channel`` it must be a dict with ``"pred"``
            and ``"class"`` entries.

        y_true : torch.Tensor
            Ground truth. Channel 0 is the centroid map; with ``separated_class_channel`` the last
            (second) channel holds the class labels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        class_preds = None
        if self.separated_class_channel:
            preds = y_pred["pred"]
            class_preds = y_pred["class"]
            assert (
                y_true.shape[1] == 2
            ), f"In separated_class_channel setting the ground truth is expected to have 2 channels. Provided {y_true.shape}"
        else:
            preds = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred

        if isinstance(preds, list):
            raw = [self.gamma**k for k in range(len(preds))]
            total = sum(raw)
            out_weights = [r / total for r in raw]
        else:
            preds = [preds]
            if self.separated_class_channel:
                class_preds = [class_preds]
            out_weights = [1.0]

        total_loss = 0
        for idx, (pred, out_w) in enumerate(zip(preds, out_weights)):
            size = pred.shape[-self.ndim :]
            target = y_true if size == y_true.shape[-self.ndim :] else scale_target(y_true, size)
            centroid_gt = target[:, 0]

            if self.class_rebalance_within_channels:
                centroid_term = F.binary_cross_entropy_with_logits(
                    pred[:, 0], centroid_gt.type(torch.float32), weight=weight_binary_ratio(centroid_gt)
                )
            else:
                centroid_term = self.centroid_loss(pred[:, 0], centroid_gt.type(torch.float32))
            step_loss = self.channel_weights[0] * centroid_term

            if self.separated_class_channel:
                # Only foreground (non-zero centroid) pixels contribute to the class term.
                fg_mask = (centroid_gt != 0).float()
                if isinstance(self.ignore_index, (int, float)):
                    fg_mask = fg_mask * (centroid_gt != self.ignore_index).float()
                per_pixel = self.class_channel_loss(class_preds[idx], target[:, -1].type(torch.long))
                class_term = (per_pixel * fg_mask).sum() / fg_mask.sum().clamp_min(1.0)
                step_loss = step_loss + self.channel_weights[1] * class_term

            total_loss += step_loss * out_w
        return total_loss
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed from logits.

    A single-channel prediction is treated as binary (sigmoid). A multi-channel prediction is
    treated as multiclass (softmax over dim 1), and an integer label-map target is one-hot
    encoded to match it.
    """

    def __init__(self, batch_dice: bool = True, smooth: float = 1e-5):
        """
        Initialize DiceLoss.

        Parameters
        ----------
        batch_dice : bool, optional
            If ``True``, intersection/union are summed over the batch and spatial axes; otherwise
            over spatial axes only, giving one Dice value per sample and class.

        smooth : float, optional
            Smoothing term added to numerator and denominator to avoid division by zero.
        """
        super().__init__()
        self.batch_dice = batch_dice
        self.smooth = smooth

    @staticmethod
    def _match_target(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Convert a label map (with or without a singleton channel axis) to ``y_pred``'s ``(B, C, *spatial)`` layout."""
        labels = y_true
        # Drop a singleton channel axis, e.g. (B, 1, H, W) -> (B, H, W).
        if labels.ndim == y_pred.ndim and labels.shape[1] == 1:
            labels = labels[:, 0]
        if y_pred.shape[1] == 1:
            # Binary prediction: the (0/1) label map itself is the target.
            target = labels.unsqueeze(1)
        else:
            one_hot = F.one_hot(labels.to(torch.int64), num_classes=y_pred.shape[1])
            # Move the class axis from last to position 1 (works for 2D and 3D alike).
            target = torch.movedim(one_hot, -1, 1)
        if target.shape != y_pred.shape:
            raise ValueError(
                f"Cannot match target of shape {tuple(y_true.shape)} to prediction of shape {tuple(y_pred.shape)}"
            )
        return target

    def forward(self, y_pred, y_true):
        """
        Compute ``1 - mean(Dice)``.

        Parameters
        ----------
        y_pred : torch.Tensor
            Logits of shape ``(B, C, *spatial)``.

        y_true : torch.Tensor
            Either a target with the same shape as ``y_pred``, or a label map of shape
            ``(B, *spatial)`` or ``(B, 1, *spatial)``.

        Returns
        -------
        loss : torch.Tensor
            Scalar Dice loss.

        Raises
        ------
        ValueError
            If the target cannot be brought to ``y_pred``'s shape.
        """
        # 1. Logits -> probabilities.
        if y_pred.shape[1] == 1:
            y_pred = torch.sigmoid(y_pred)
        else:
            y_pred = torch.softmax(y_pred, dim=1)

        # 2. Bring a label-map target to the prediction's layout.
        if y_pred.shape != y_true.shape:
            y_true = self._match_target(y_true, y_pred)
        y_true = y_true.float()

        # 3. Reduction axes: spatial ones, plus the batch axis when batch_dice is set.
        axes = list(range(2, y_pred.ndim))
        if self.batch_dice:
            axes = [0] + axes

        intersection = torch.sum(y_pred * y_true, dim=axes)
        union = torch.sum(y_pred, dim=axes) + torch.sum(y_true, dim=axes)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        # Mean over the remaining (class and, if not batch_dice, batch) axes.
        return 1.0 - torch.mean(dice)
class DiceCELoss(nn.Module):
    """
    Combines Cross Entropy (or Binary Cross Entropy) and Dice Loss.

    Supports multi-head (separated_class_channel), deep supervision lists, class rebalancing,
    and channel weighting.
    """

    def __init__(
        self,
        num_classes: int,
        ndim: int = 2,
        separated_class_channel: bool = False,
        batch_dice: bool = True,
        smooth: float = 1e-5,
        model_source: str = "biapy",
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        channel_weights: Tuple[float, float] = (1.0, 1.0),
        w_ce: float = 1.0,
        w_dice: float = 1.0,
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize DiceCELoss.

        Parameters
        ----------
        num_classes : int
            Number of classes. ``<= 2`` selects the binary (BCE + sigmoid Dice) path for the
            single-head case.

        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.

        separated_class_channel : bool, optional
            For separated_class_channel predictions e.g. points + classification in detection.
            The spatial head (channel 0 of ``"pred"``) is always treated as a binary map (BCE +
            sigmoid Dice). The ``"class"`` head is scored with CrossEntropy against the last
            target channel.

        batch_dice : bool, optional
            Whether to calculate Dice loss across the batch dimension or not.

        smooth : float, optional
            Smoothing factor to avoid division by zero in Dice loss.

        model_source : str, optional
            Source of the model. It can be "biapy", "bmz" or "torchvision".

        class_rebalance: str, optional
            Whether to reweight classes (inside loss function) or not. Options are: "none", "auto" and "manual".
            "auto" weights binary BCE terms with ``weight_binary_ratio`` of the target.

        class_weights : list of float, optional
            Per-class weights. In the single-head case they go to the CE/BCE loss. With
            ``separated_class_channel`` they go to the classification-head CrossEntropy.
            E.g. ``[0.7, 0.3]`` for 2 classes.

        channel_weights : tuple of float, optional
            With ``separated_class_channel``: weights of the spatial (CE + Dice) term and of the
            classification term respectively. E.g. ``(1, 0.2)``.

        w_ce : float, optional
            Weight for the Cross Entropy component of the loss.

        w_dice : float, optional
            Weight for the Dice component of the loss.

        ignore_index : int, optional
            Value to ignore in the CrossEntropy calculations. ``-1`` maps to PyTorch's default ``-100``.

        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps", "xpu", "xla" or "meta".
        """
        super(DiceCELoss, self).__init__()
        self.num_classes = num_classes
        self.ndim = ndim
        self.separated_class_channel = separated_class_channel
        self.batch_dice = batch_dice
        self.smooth = smooth
        self.model_source = model_source
        self.class_rebalance = class_rebalance
        self.channel_weights = channel_weights
        self.w_ce = w_ce
        self.w_dice = w_dice
        self.ignore_index = ignore_index if ignore_index != -1 else -100
        self.device = device if device is not None else torch.device("cpu")
        self.gamma = 0.5  # For intermediate outputs weighting

        self.class_weights_tensor = None
        if class_weights is not None and len(class_weights) > 0:
            self.class_weights_tensor = torch.tensor(class_weights, device=self.device, dtype=torch.float32)

        # Standard CE/BCE for the single-head path.
        if num_classes <= 2:
            self.loss = torch.nn.BCEWithLogitsLoss(weight=self.class_weights_tensor)
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights_tensor)

        if self.separated_class_channel:
            # The spatial head is a single binary channel whatever ``num_classes`` is, so it needs
            # BCE (multiclass CE over one channel would treat a spatial axis as classes).
            self.spatial_loss = torch.nn.BCEWithLogitsLoss()
            # Classification head loss; class weights belong to the class axis here.
            self.class_channel_loss = torch.nn.CrossEntropyLoss(
                ignore_index=self.ignore_index, weight=self.class_weights_tensor
            )

    def _compute_dice(self, y_pred: torch.Tensor, y_true: torch.Tensor, binary: Optional[bool] = None) -> torch.Tensor:
        """
        Compute ``1 - mean(Dice)`` on spatial outputs.

        Parameters
        ----------
        y_pred : torch.Tensor
            Logits ``(B, C, *spatial)``.

        y_true : torch.Tensor
            Binary target shaped like ``y_pred`` (binary mode), or a label map with or without a
            singleton channel axis (multiclass mode).

        binary : bool, optional
            Force sigmoid (``True``) or softmax/one-hot (``False``). Defaults to
            ``num_classes <= 2``.
        """
        if binary is None:
            binary = self.num_classes <= 2

        if binary:
            pred_probs = torch.sigmoid(y_pred)
            dice_target = y_true.reshape_as(pred_probs).float()
        else:
            pred_probs = torch.softmax(y_pred, dim=1)
            if y_true.ndim == y_pred.ndim:
                y_true_squeezed = y_true.squeeze(1).long()
            else:
                y_true_squeezed = y_true.long()
            dice_target = F.one_hot(y_true_squeezed, num_classes=self.num_classes)
            if y_pred.ndim == 4:  # 2D: (B, C, H, W)
                dice_target = dice_target.permute(0, 3, 1, 2)
            elif y_pred.ndim == 5:  # 3D: (B, C, D, H, W)
                dice_target = dice_target.permute(0, 4, 1, 2, 3)
            dice_target = dice_target.float()

        axes = list(range(2, y_pred.ndim))
        if self.batch_dice:
            axes = [0] + axes

        intersection = torch.sum(pred_probs * dice_target, dim=axes)
        union = torch.sum(pred_probs, dim=axes) + torch.sum(dice_target, dim=axes)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - torch.mean(dice)

    def forward(self, y_pred: Union[torch.Tensor, dict, List], y_true: torch.Tensor) -> torch.Tensor:
        """
        Calculate Dice + CE loss, mirroring the CrossEntropyLoss_wrapper structure.

        Parameters
        ----------
        y_pred : torch.Tensor, dict, or list of Tensors
            Model predictions. Can be a single tensor, a dict with "pred" and optionally "class"
            keys, or a list of tensors for deep supervision (output ``k`` weighted by
            ``gamma**k``, normalized).

        y_true : torch.Tensor
            Ground truth masks. With ``separated_class_channel`` it must have 2 channels:
            the spatial target and the class labels.

        Returns
        -------
        loss : torch.Tensor
            Combined Dice + CE loss value.
        """
        # 1. Extract predictions.
        _y_pred_class = None
        if self.separated_class_channel:
            _y_pred = y_pred["pred"]
            _y_pred_class = y_pred["class"]
            assert (
                y_true.shape[1] == 2
            ), f"In separated_class_channel setting the ground truth is expected to have 2 channels. Provided {y_true.shape}"
        else:
            _y_pred = y_pred["pred"] if isinstance(y_pred, dict) and "pred" in y_pred else y_pred

        # 2. BMZ binary models may output (background, foreground) channels.
        if (
            self.model_source == "bmz"
            and self.num_classes <= 2
            and isinstance(_y_pred, torch.Tensor)
            and _y_pred.shape[1] != y_true.shape[1]
        ):
            y_true = torch.cat((1 - y_true, y_true), 1)

        # 3. Deep supervision weights.
        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]
            _y_pred_class = [_y_pred_class] if self.separated_class_channel else None
            inter_output_weights = [1.0]
        else:
            w = [self.gamma**i for i in range(len(_y_pred))]
            s = sum(w)
            inter_output_weights = [x / s for x in w]

        # 4. Accumulate the combined loss over outputs.
        loss = 0
        for j, (pred, out_w) in enumerate(zip(_y_pred, inter_output_weights)):
            spatial = pred.shape[-self.ndim :]
            _y_true = scale_target(y_true, spatial) if spatial != y_true.shape[-self.ndim :] else y_true

            if self.separated_class_channel:
                target0 = _y_true[:, 0]
                if self.class_rebalance == "auto":
                    spatial_fn = torch.nn.BCEWithLogitsLoss(weight=weight_binary_ratio(target0))
                else:
                    spatial_fn = self.spatial_loss
                # Spatial branch (BCE + sigmoid Dice) on the single binary channel.
                ce_spatial = spatial_fn(pred[:, 0], target0.float())
                dice_spatial = self._compute_dice(pred[:, 0].unsqueeze(1), target0.unsqueeze(1), binary=True)
                combined_spatial_loss = (self.w_ce * ce_spatial) + (self.w_dice * dice_spatial)

                # Classification branch (CE only).
                class_loss = self.class_channel_loss(_y_pred_class[j], _y_true[:, -1].type(torch.long))

                _loss = (self.channel_weights[0] * combined_spatial_loss) + (self.channel_weights[1] * class_loss)
            else:
                if self.num_classes <= 2:
                    if self.class_rebalance == "auto":
                        loss_fn = torch.nn.BCEWithLogitsLoss(weight=weight_binary_ratio(_y_true))
                    else:
                        loss_fn = self.loss
                    ce_l = loss_fn(pred, _y_true.type(torch.float32))
                else:
                    ce_l = self.loss(pred, _y_true[:, 0].type(torch.long))
                dice_l = self._compute_dice(pred, _y_true)
                _loss = (self.w_ce * ce_l) + (self.w_dice * dice_l)

            loss += _loss * out_w
        return loss
class ContrastCELoss(nn.Module):
    """
    Contrastive Cross Entropy Loss for semantic segmentation tasks.

    It mixes the main loss function and the contrastive loss.

    Parameters
    ----------
    main_loss : nn.Module
        The main loss function to be used for the segmentation task.
    ndim : int, optional
        Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.
    weight : float, optional
        Weight for the contrastive loss. Default is 1.0. It balances the contribution of the
        contrastive term in the final loss and is only applied when ``with_embed`` is ``True``.
    ignore_index : int, optional
        Label to ignore in the contrastive loss calculation. Default is -1.
    """

    def __init__(
        self,
        main_loss: nn.Module,
        ndim: int = 2,
        weight: float = 1.0,
        ignore_index: int = -1,
    ):
        """
        Initialize the ContrastCELoss module.

        Parameters
        ----------
        main_loss : nn.Module
            The main loss function to be used for the segmentation task.

        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.

        weight : float, optional
            Weight for the contrastive loss. Default is 1.0.

        ignore_index : int, optional
            Label to ignore in the contrastive loss calculation (forwarded to ``PixelContrastLoss``). Default is -1.
        """
        super(ContrastCELoss, self).__init__()
        self.ndim = ndim
        self.main_loss = main_loss
        self.contrast_criterion = PixelContrastLoss(ignore_index=ignore_index, ndim=ndim)
        self.loss_weight = weight

    def forward(self, preds, target, with_embed=False):
        """
        Forward pass of the Contrastive Cross Entropy Loss.

        Parameters
        ----------
        preds : dict
            Dictionary containing the predictions from the model. It should contain:
                - "pred": Segmentation predictions (required).
                - "embed": Embedding predictions (required).
                - "segment_queue": Segment queue for contrastive learning (optional).
                - "pixel_queue": Pixel queue for contrastive learning (optional).
            The contrastive term is computed only when both queues are present.

        target : torch.Tensor
            Ground truth segmentation masks.

        with_embed : bool, optional
            Whether to include the (weighted) contrastive term in the returned loss. Default is False.

        Returns
        -------
        loss : torch.Tensor
            ``main_loss + weight * contrastive`` if ``with_embed``, otherwise
            ``main_loss + 0 * contrastive``.
        """
        assert "pred" in preds, "Segmentation prediction is missing in the input dictionary."
        assert "embed" in preds, "Embedding prediction is missing in the input dictionary."

        seg = preds["pred"]
        embedding = preds["embed"]
        segment_queue = preds["segment_queue"] if "segment_queue" in preds else None
        pixel_queue = preds["pixel_queue"] if "pixel_queue" in preds else None

        # Resize the segmentation logits (not the target) when their spatial sizes differ.
        if seg.shape[-self.ndim :] != target.shape[-self.ndim :]:
            mode = "bilinear" if self.ndim == 2 else "trilinear"
            pred = F.interpolate(input=seg, size=target.shape[-self.ndim :], mode=mode, align_corners=True)
        else:
            pred = seg

        loss = self.main_loss(pred, target)

        loss_contrast = 0
        if segment_queue is not None and pixel_queue is not None:
            queue = torch.cat((segment_queue, pixel_queue), dim=1)

            # When the classes are less or equal 2 the background class channel is not added in BiaPy
            # so can't apply directly an argmax/max operation
            if seg.shape[1] <= 2:
                # NOTE(review): ``seg.max(dim=1)`` returns (values, indices) and removes dim 1, so
                # ``predict`` holds per-pixel argmax indices with no channel axis. The (1, 2, 1, ...)
                # ``offsets`` below therefore broadcast against the batch axis, not a channel axis
                # (which also requires a batch size of 1 or 2). Confirm this is the intended label
                # derivation for <=2-channel outputs.
                _, predict = seg.max(dim=1)
                # ``predict.ndim == 3`` corresponds to 2D inputs (B, H, W).
                if predict.ndim == 3:
                    offsets = torch.tensor([1, 2], device=seg.device).view(1, 2, 1, 1)
                else:
                    offsets = torch.tensor([1, 2], device=seg.device).view(1, 2, 1, 1, 1)
                predict = predict * offsets
                predict, _ = predict.max(dim=1)
            else:
                predict = torch.argmax(seg, 1)

            # Uses the un-resized ``seg``-derived ``predict`` and the raw ``embedding``.
            loss_contrast += self.contrast_criterion(
                embedding,
                labels=target,
                predict=predict,
                queue=queue,
            )
        else:
            loss_contrast += 0

        if with_embed:
            return loss + self.loss_weight * loss_contrast

        # Presumably the zero-weighted term keeps the contrastive branch in the autograd graph so
        # all parameters receive gradients under distributed training — confirm.
        return loss + 0 * loss_contrast  # just a trick to avoid errors in distributed training
class PixelContrastLoss(nn.Module):
    """
    Pixel Contrastive Loss for semantic segmentation tasks.

    Supports hard anchor sampling and negative sampling for contrastive learning. Internal buffers are
    allocated on the device of the tensors they are built from, so the loss runs on CPU, CUDA or MPS.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        base_temperature: float = 0.07,
        ignore_index: int = -1,
        max_samples: int = 1024,
        max_views: int = 1,
        ndim: int = 2,
    ):
        """
        Initialize the Pixel Contrastive Loss for semantic segmentation tasks.

        Parameters
        ----------
        temperature : float, optional
            Temperature parameter for the contrastive loss. Default is 0.07.
        base_temperature : float, optional
            Base temperature for the contrastive loss. Default is 0.07.
        ignore_index : int, optional
            Label to ignore in the loss calculation. Default is -1.
        max_samples : int, optional
            Maximum number of samples to consider for the contrastive loss. Default is 1024.
        max_views : int, optional
            Maximum number of views to consider for the contrastive loss. Default is 1.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes. Default is 2.
        """
        super(PixelContrastLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.ignore_index = ignore_index
        self.max_samples = max_samples
        self.max_views = max_views
        self.ndim = ndim

    def _hard_anchor_sampling(
        self, X: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample hard anchors from the input features and their corresponding labels.

        Parameters
        ----------
        X : torch.Tensor
            Input features of shape (batch_size, num_samples, feature_dim). E.g. (2, 32768, 256) for a batch
            size of 2, 32768 samples and 256 features.
        y_hat : torch.Tensor
            Ground truth labels of shape (batch_size, num_samples). E.g. (2, 32768) for a batch size of 2 and
            32768 samples.
        y : torch.Tensor
            Predicted labels of shape (batch_size, num_samples). E.g. (2, 32768) for a batch size of 2 and
            32768 samples.

        Returns
        -------
        X_ : torch.Tensor
            Sampled features of shape (total_classes, n_view, feature_dim), where ``n_view`` is at most
            ``self.max_views``. E.g. (82, 1, 256) for 82 classes (this can vary depending on the classes
            found in the ground truth), and 256 features. ``None`` when no class qualifies.
        y_ : torch.Tensor
            Sampled labels of shape (total_classes,). E.g. (82,) for 82 classes (this can vary depending on
            the classes found in the ground truth). ``None`` when no class qualifies.
        """
        batch_size, feat_dim = X.shape[0], X.shape[-1]

        # Per image: classes that are not ignored and have more than max_views pixels
        classes = []
        total_classes = 0
        for ii in range(batch_size):
            this_y = y_hat[ii]
            this_classes = torch.unique(this_y)
            this_classes = [x for x in this_classes if x != self.ignore_index]
            this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views]
            classes.append(this_classes)
            total_classes += len(this_classes)

        if total_classes == 0:
            return None, None  # type: ignore

        n_view = self.max_samples // total_classes
        n_view = min(n_view, self.max_views)

        # Allocate on the features' device instead of hard-coding CUDA
        X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float, device=X.device)
        y_ = torch.zeros(total_classes, dtype=torch.float, device=X.device)

        X_ptr = 0
        for ii in range(batch_size):
            this_y_hat = y_hat[ii]
            this_y = y[ii]
            this_classes = classes[ii]

            for cls_id in this_classes:
                # "hard" pixels are mispredicted by the current prediction, "easy" ones are correct
                hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
                easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()

                num_hard = hard_indices.shape[0]
                num_easy = easy_indices.shape[0]

                if num_hard >= n_view / 2 and num_easy >= n_view / 2:
                    num_hard_keep = n_view // 2
                    num_easy_keep = n_view - num_hard_keep
                elif num_hard >= n_view / 2:
                    num_easy_keep = num_easy
                    num_hard_keep = n_view - num_easy_keep
                elif num_easy >= n_view / 2:
                    num_hard_keep = num_hard
                    num_easy_keep = n_view - num_hard_keep
                else:
                    raise Exception("this shoud be never touched! {} {} {}".format(num_hard, num_easy, n_view))

                perm = torch.randperm(num_hard)
                hard_indices = hard_indices[perm[:num_hard_keep]]
                perm = torch.randperm(num_easy)
                easy_indices = easy_indices[perm[:num_easy_keep]]
                indices = torch.cat((hard_indices, easy_indices), dim=0)

                X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1)
                y_[X_ptr] = cls_id
                X_ptr += 1

        return X_, y_

    def _sample_negative(self, Q: torch.Tensor):
        """
        Sample negative examples from the queue.

        The queue is expected to be of shape (class_num, cache_size, feat_size), where:
        - class_num is the number of classes,
        - cache_size is the number of samples per class,
        - feat_size is the size of the feature vector.

        Parameters
        ----------
        Q : torch.Tensor
            Queue of shape (class_num, cache_size, feat_size). E.g. (2, 60, 256) for 2 classes, 60 samples
            per class and 256 features.

        Returns
        -------
        X_ : torch.Tensor
            Sampled negative examples of shape (class_num * cache_size, feat_size). E.g. (120, 256) for
            2 classes, 60 samples per class and 256 features.
        y_ : torch.Tensor
            Labels of shape (class_num * cache_size, 1).
        """
        class_num, cache_size, feat_size = Q.shape

        X_ = torch.zeros((class_num * cache_size, feat_size), device=Q.device).float()
        y_ = torch.zeros((class_num * cache_size, 1), device=Q.device).float()

        sample_ptr = 0
        for ii in range(class_num):
            # Class 0 is skipped, so the last `cache_size` rows stay zero-filled with label 0
            if ii == 0:
                continue
            this_q = Q[ii, :cache_size, :]

            X_[sample_ptr : sample_ptr + cache_size, ...] = this_q
            y_[sample_ptr : sample_ptr + cache_size, ...] = ii
            sample_ptr += cache_size

        return X_, y_

    def _contrastive(
        self,
        X_anchor: torch.Tensor,
        y_anchor: torch.Tensor,
        queue: Optional[torch.Tensor] = None,
    ):
        """
        Contrastive loss calculation.

        Parameters
        ----------
        X_anchor : torch.Tensor
            Anchor features of shape (total_classes, n_view, feature_dim). E.g. (82, 1, 256) for 82 classes
            (this can vary depending on the classes found in the ground truth), and 256 features.
        y_anchor : torch.Tensor
            Anchor labels of shape (total_classes,). E.g. (82,) for 82 classes (this can vary depending on
            the classes found in the ground truth).
        queue : torch.Tensor, optional
            Queue of negative examples of shape (class_num, cache_size, feat_size). E.g. (19, 10000, 256)
            for 19 classes, 10000 samples per class and 256 features. If not provided, the contrastive loss
            will be calculated using the anchor features only.

        Returns
        -------
        loss : torch.Tensor
            Scalar contrastive loss.
        """
        anchor_num, n_view = X_anchor.shape[0], X_anchor.shape[1]

        y_anchor = y_anchor.contiguous().view(-1, 1)
        anchor_count = n_view
        # Stack views one after another: (n_view * anchor_num, feature_dim)
        anchor_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0)

        if queue is not None:
            X_contrast, y_contrast = self._sample_negative(queue)
            y_contrast = y_contrast.contiguous().view(-1, 1)
            contrast_count = 1
            contrast_feature = X_contrast
        else:
            y_contrast = y_anchor
            contrast_count = n_view
            contrast_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0)

        mask = torch.eq(y_anchor, y_contrast.T).float()

        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
        # Subtract the row max for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask = mask.repeat(anchor_count, contrast_count)
        neg_mask = 1 - mask

        # Zero the diagonal so an anchor is not its own positive.
        # NOTE(review): with a queue this still needs anchor_num * anchor_count <= number of queue rows.
        self_index = torch.arange(anchor_num * anchor_count, device=mask.device).view(-1, 1)
        logits_mask = torch.ones_like(mask).scatter_(1, self_index, 0)
        mask = mask * logits_mask

        neg_logits = torch.exp(logits) * neg_mask
        neg_logits = neg_logits.sum(1, keepdim=True)

        exp_logits = torch.exp(logits)

        log_prob = logits - torch.log(exp_logits + neg_logits)

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()

        return loss

    def forward(
        self,
        feats: torch.Tensor,
        labels: torch.Tensor,
        predict: torch.Tensor,
        queue: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass of the Pixel Contrastive Loss.

        Parameters
        ----------
        feats : torch.Tensor
            Input features of shape (batch_size, feat_size, H, W) or (batch_size, feat_size, D, H, W).
            E.g. (2, 256, 128, 256) for a batch size of 2, 256 features and a spatial size of 128x256.
        labels : torch.Tensor
            Ground truth labels of shape (batch_size, C, H, W) or (batch_size, C, D, H, W). E.g.
            (2, 1, 128, 256) for a batch size of 2, 1 channel and a spatial size of 128x256. When ``C > 1``
            the channels are treated as binary masks and merged into one label map (channel ``k`` becomes
            label ``k + 1``; the highest active channel wins).
        predict : torch.Tensor
            Predicted labels of shape (batch_size, H, W) or (batch_size, D, H, W). E.g. (2, 128, 256) for a
            batch size of 2 and a spatial size of 128x256.
        queue : torch.Tensor, optional
            Queue of negative examples of shape (class_num, cache_size, feat_size). E.g. (2, 60, 256) for
            2 classes, 60 samples per class and 256 features. If not provided, the contrastive loss will be
            calculated using the anchor features only.

        Returns
        -------
        loss : torch.Tensor or int
            Contrastive loss value, or ``0`` when no class has enough pixels to be sampled.
        """
        labels = torch.nn.functional.interpolate(labels.float().clone(), feats.shape[-self.ndim :], mode="nearest")

        # When working in instance segmentation the channels are more than 1 so we need to merge them into
        # just one channel.
        if labels.shape[1] != 1:
            offsets = torch.arange(1, labels.shape[1] + 1, device=labels.device, dtype=labels.dtype)
            offsets = offsets.view(1, -1, *([1] * (labels.ndim - 2)))
            labels, _ = (labels * offsets).max(dim=1)
        # In semantic the target is already compressed into one channel
        else:
            labels = labels.squeeze(1)
        labels = labels.long()
        assert labels.shape[-1] == feats.shape[-1], "Labels ({}) and features ({}) does not match in shape".format(
            labels.shape, feats.shape
        )

        batch_size = feats.shape[0]

        labels = labels.contiguous().view(batch_size, -1)
        predict = predict.contiguous().view(batch_size, -1)
        # Channels-last, then flatten the spatial axes: (batch_size, num_pixels, feat_size)
        if feats.ndim == 4:
            feats = feats.permute(0, 2, 3, 1)
        else:
            feats = feats.permute(0, 2, 3, 4, 1)
        feats = feats.contiguous().view(feats.shape[0], -1, feats.shape[-1])

        feats_, labels_ = self._hard_anchor_sampling(feats, labels, predict)
        if feats_ is not None and labels_ is not None:
            loss = self._contrastive(feats_, labels_, queue=queue)
            return loss
        else:
            return 0
class instance_segmentation_loss:
    """
    Custom loss for instance segmentation tasks in BiaPy.

    This loss combines different loss functions (e.g., BCE, L1, CrossEntropy) for multiple output channels,
    such as binary masks, contours, distances, and class channels. It supports class rebalancing, masking of
    distance channels, and different instance segmentation output types (e.g., "regular", "synapses").
    The loss is configurable for various output channel combinations and can handle multi-class and
    multi-head settings.

    Parameters
    ----------
    channel_weights : tuple of float, optional
        Weights to be applied to each output channel loss. E.g. (1, 0.2).
    out_channels : List of str, optional
        String specifying the output channels (e.g., ["F", "C"], ["B", "C", "P"], ["B","C","D"], etc.).
    losses_to_use : list of str, optional
        List of loss functions to use for each output channel (e.g., ["bce", "mse"]).
    channel_extra_opts : dict, optional
        Additional options for each output channel (e.g., {"D": {"mask_values": True}}).
    gt_channels_expected : int, optional
        Number of channels expected in the ground truth (default: 1).
    class_rebalance : str, optional
        Whether to reweight classes (inside loss function) or not. Options are: "none" and "manual".
    class_weights : List[float], optional
        Weights for each class to be used in the loss calculation (default: None).
    ignore_index : int, optional
        Value to ignore in the loss calculation (default: -1).
    """

    def __init__(
        self,
        channel_weights=(1, 1),
        ndim: int = 2,
        class_rebalance_within_channels: bool = False,
        separated_class_channel: bool = False,
        out_channels=["F", "C"],
        losses_to_use=[],
        channel_extra_opts={},
        gt_channels_expected: int = 1,
        class_rebalance: str = "none",
        class_weights: List[float] = [],
        ignore_index: int = -1,
        device=None,
    ):
        """
        Initialize the custom loss that mixes BCE and MSE depending on the ``out_channels`` variable.

        Parameters
        ----------
        channel_weights : 2 float tuple, optional
            Weights to be applied to each channel of the data. E.g. if working with F + C channels, you can
            provide for example these weights: ``(1, 2)`` so the contours will have more importance in the
            loss calculation. When ``separated_class_channel`` is True the last weight is also used for
            the class channel loss.
        ndim : int, optional
            Number of dimensions of the input data. 2 for 2D images, 3 for 3D volumes.
        class_rebalance_within_channels : bool, optional
            Whether to apply a rebalancing strategy to the loss function to give more importance to
            underrepresented pixels within the channels. The weights are calculated automatically based on
            the number of pixels of each class. In the specific case of detection, where there are usually
            much less pixels representing the center of the objects to detect than background pixels, with
            this option activated, the loss will give more importance to the pixels representing the center
            of the objects to help the model learn better to predict them.
        separated_class_channel : bool, optional
            When a separated class channel is expected in the predictions e.g. instances + classification in
            instance segmentation.
        out_channels : List of str, optional
            Channels to operate with. A ``"We"`` entry is not a prediction channel: it means the last GT
            channel holds per-pixel border weights.
        losses_to_use : list of str, optional
            List of loss functions to use for each output channel (e.g., ["ce", "bce", "mae"]).
        channel_extra_opts : dict, optional
            Additional options for each output channel (e.g., {"B": {"mask_values": True}}).
        gt_channels_expected : int, optional
            Number of channels expected in the ground truth (default: 1). This is used to check that the GT
            loaded has the expected number of channels. If ``extra_weight_in_borders`` is ``True``, then 1
            channel will be added to the expected GT channels to account for the extra weight in borders
            channel.
        class_rebalance: str, optional
            Whether to reweight classes or not. This only works if ``separated_class_channel`` is ``True``.
            Options are: "none" and ``"manual"``.
        class_weights : list of float, optional
            List of weights for each class to be used in ``"manual"`` class rebalancing. This only works when
            ``separated_class_channel`` is ``True`` and ``class_rebalance`` is ``"manual"``. E.g.
            ``[1, 1.7, 0.5]`` for 3 classes.
        ignore_index : int, optional
            Value to ignore in the loss calculation. ``-1`` (default) disables ignoring.
        device : Torch device, optional
            Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps".
        """
        if separated_class_channel:
            if class_rebalance == "manual" and not class_weights:
                raise ValueError("class_weights must be provided when class_rebalance is 'manual'")

        self.channel_weights = channel_weights
        self.class_rebalance_within_channels = class_rebalance_within_channels
        self.separated_class_channel = separated_class_channel
        self.ndim = ndim
        self.out_channels = [x for x in out_channels if x != "We"]
        self.extra_weight_in_borders = out_channels.count("We") > 0
        self.gt_channels_expected = gt_channels_expected if not self.extra_weight_in_borders else gt_channels_expected + 1
        self.channel_extra_opts = channel_extra_opts
        self.class_rebalance = class_rebalance
        self.class_weights = None
        self.ignore_index = ignore_index
        self.ignore_values = ignore_index != -1
        self.losses_to_use = losses_to_use
        self.device = device if device is not None else torch.device("cpu")
        # Decay factor for the weights of multiple outputs (output i gets gamma**i before normalisation)
        self.gamma = 0.5

        if self.class_rebalance == "manual":
            self.class_weights = torch.tensor(class_weights, device=device, dtype=torch.float32)

        if len(self.losses_to_use) != 0:
            assert len(self.out_channels) == len(self.losses_to_use), "Length of out_channels and losses_to_use should be the same. Provided {} and {}".format(
                self.out_channels, self.losses_to_use
            )

        if self.separated_class_channel:
            self.class_channel_loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, weight=self.class_weights, reduction="none")

    def _foreground_mask(self, y_true: torch.Tensor) -> "Optional[torch.Tensor]":
        """Return a (B, 1, ...) float foreground mask from binary GT channels.

        Checks for F or M first (foreground=1), then B (background=1 → foreground=0).
        Returns None when no binary channel is present in the GT.
        """
        for j, ch in enumerate(self.out_channels):
            if ch in ("F", "M"):
                return (y_true[:, j : j + 1] > 0).float()
            if ch == "B":
                return (y_true[:, j : j + 1] == 0).float()
        return None

    @staticmethod
    def _align_to(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Reshape per-pixel tensor ``t`` so it multiplies ``ref`` element-wise (batch axis against batch axis).

        This loss mixes channel-less ``(B, *spatial)`` tensors (CrossEntropy outputs, border weights) with
        ``(B, C, *spatial)`` ones. Plain broadcasting between them silently yields a ``(B, B, ...)`` tensor,
        so a singleton channel axis is dropped, or a missing one is added, as needed.
        """
        if t.ndim == ref.ndim + 1 and t.shape[1] == 1:
            return t.squeeze(1)
        if t.ndim + 1 == ref.ndim:
            return t.unsqueeze(1)
        return t

    def _channel_bounds(self, channel: str, n_pred_channels: int) -> Tuple[int, int, int, int]:
        """Return ``(pred_start, pred_end, gt_start, gt_end)`` channel ranges for output ``channel``."""
        # NOTE(review): the start offset is the channel's position in `out_channels`, i.e. every earlier
        # entry is assumed to occupy a single channel — confirm when multi-channel entries ("R",
        # discretized "Db") are not placed last.
        pred_start = self.out_channels.index(channel)
        gt_start = pred_start
        if channel == "A":
            # "A" takes every remaining prediction channel
            pred_end = n_pred_channels
            gt_end = pred_end
        elif channel == "R":
            pred_end = self.channel_extra_opts["R"].get("nrays", 32) + pred_start
            gt_end = pred_end
        elif channel == "Db":
            db_opts = self.channel_extra_opts.get("Db", {})
            if db_opts.get("val_type", "norm") == "discretize":
                # One prediction channel per bin (e.g. bin_size 0.1 -> 11 bins) against a single GT channel
                db_channels = int(round(1.0 / db_opts.get("bin_size", 0.1))) + 1
            else:
                db_channels = 1
            pred_end = pred_start + db_channels
            gt_end = pred_start + 1
        else:
            pred_end = pred_start + 1
            gt_end = pred_end
        return pred_start, pred_end, gt_start, gt_end

    def _element_mask(self, channel: str, y_true: torch.Tensor, y_true_slice: torch.Tensor) -> "Optional[torch.Tensor]":
        """Return the element mask for ``channel`` (``None`` means every element counts)."""
        mask_vals = self.channel_extra_opts.get(channel, {}).get("mask_values", False)
        mask = None
        if channel in ("Gv", "Gh", "Gz"):
            # Flow targets are unit vectors: magnitude is always 1 in the foreground
            # and (0, 0[, 0]) in the background. A foreground pixel with purely
            # horizontal flow has Gv=0, so per-component (!=0) masking would wrongly
            # exclude it. Sum of squared components is the correct foreground proxy.
            flow_channels = [j for j, ch in enumerate(self.out_channels) if ch in ("Gv", "Gh", "Gz")]
            mag_sq = sum(y_true[:, j : j + 1].float() ** 2 for j in flow_channels)
            mask = (mag_sq > 0).float()
        elif channel in ("H", "V", "Z"):
            # HoVer-Net channels: values in [-1, 1] (signed), centroid = 0 = background.
            # Masking by (!=0) would exclude the entire center row/column of every cell.
            # Use a binary foreground channel (F/M/B) when present; otherwise train on
            # all pixels — background=0 is a valid target and the network learns to
            # predict 0 there naturally.
            mask = self._foreground_mask(y_true)
        elif channel in ("Db", "Dc", "Dn", "R"):
            # Distance channels where legitimate foreground pixels can have value 0:
            #   Db: boundary pixels have Db=0 after per-cell normalization
            #   Dc: the centroid pixel has Dc=0 (most important point)
            #   Dn: isolated cells (no neighbor) have Dn=0
            #   R:  ray distances near boundary approach 0
            # Use a binary foreground channel (F/M/B) when available; fall back to
            # (y_true_slice > 0) which excludes background but also misses genuine
            # zero-valued foreground pixels (e.g. Dc at centroid, Db at boundary).
            mask = self._foreground_mask(y_true)
            if mask is None and mask_vals:
                mask = (y_true_slice > 0).float()
        elif mask_vals:
            mask = (y_true_slice != 0).float()

        if self.ignore_values:
            # Combine with the ignore mask even when no channel mask exists (previously `None * tensor`)
            keep = (y_true_slice != self.ignore_index).float()
            mask = keep if mask is None else mask * keep
        return mask

    def _channel_loss(self, i: int, channel: str, pred_out: torch.Tensor, y_true: torch.Tensor, w_borders):
        """Return the reduced loss of output ``channel`` (position ``i`` in ``out_channels``)."""
        pred_start, pred_end, gt_start, gt_end = self._channel_bounds(channel, pred_out.shape[1])
        y_pred_slice = pred_out[:, pred_start:pred_end]
        y_true_slice = y_true[:, gt_start:gt_end].float()

        mask = self._element_mask(channel, y_true, y_true_slice)

        if y_pred_slice.shape[-self.ndim :] != y_true_slice.shape[-self.ndim :]:
            y_true_slice = scale_target(y_true_slice, y_pred_slice.shape[-self.ndim :])

        loss_name = self.losses_to_use[i]

        # class-rebalance / ignore_index element weights for binary-like channels
        weight = None
        if loss_name in ["bce", "ce"] and channel in ["B", "F", "P", "C", "T", "A", "M", "F_pre", "F_post"]:
            if self.class_rebalance_within_channels:
                weight = weight_binary_ratio(y_true_slice).float()
            if self.ignore_values:
                ignore_mask = (y_true_slice != self.ignore_index).float()
                weight = ignore_mask if weight is None else weight * ignore_mask

        # Criteria without reduction so masking/weighting can be applied per element. Element weights are
        # applied after the criterion: equivalent to BCEWithLogitsLoss(weight=...), and required for CE,
        # whose `weight` argument is per class, not per element.
        if loss_name == "bce":
            crit = torch.nn.BCEWithLogitsLoss(reduction="none")
        elif loss_name == "ce":
            crit = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction="none")
            y_true_slice = y_true_slice.long().squeeze(1)
        elif loss_name in ["l1", "mae"]:
            crit = torch.nn.L1Loss(reduction="none")
        elif loss_name == "mse":
            crit = torch.nn.MSELoss(reduction="none")
        else:
            raise ValueError("Loss function {} not recognized".format(self.losses_to_use[i]))

        if loss_name != "ce":
            y_pred_slice = y_pred_slice.float()
            y_true_slice = y_true_slice.float()

        loss_tensor = crit(y_pred_slice, y_true_slice)

        if weight is not None:
            loss_tensor = loss_tensor * self._align_to(weight, loss_tensor)
        if w_borders is not None:
            loss_tensor = loss_tensor * self._align_to(w_borders, loss_tensor)

        if mask is None:
            return loss_tensor.sum() / loss_tensor.numel()

        mask = self._align_to(mask, loss_tensor)
        loss_tensor = loss_tensor * mask
        # Count every element the mask selects (including broadcast channels) so the masked mean is on
        # the same scale as the unmasked one.
        denom = mask.expand_as(loss_tensor).sum().clamp_min(1.0)
        return loss_tensor.sum() / denom

    def _class_channel_loss(self, pred_class: torch.Tensor, y_true: torch.Tensor, w_borders) -> torch.Tensor:
        """Return the CrossEntropy loss of the separated class channel over GT instance pixels only."""
        class_gt = y_true[:, -1]
        loss_tensor = self.class_channel_loss(pred_class, class_gt.type(torch.long))
        if w_borders is not None:
            loss_tensor = loss_tensor * self._align_to(w_borders, loss_tensor)
        # Apply always a mask to the class channel to consider only the pixels where there is an instance
        # (non-zero in the GT), otherwise the loss value can be very low and not contribute to the training.
        mask = (class_gt != 0).float()
        if self.ignore_values:
            mask = mask * (class_gt != self.ignore_index).float()
        loss_tensor = loss_tensor * mask
        denom = mask.sum().clamp_min(1.0)
        return loss_tensor.sum() / denom

    def __call__(self, y_pred, y_true):
        """
        Calculate instance segmentation loss.

        Parameters
        ----------
        y_pred : torch.Tensor, list of torch.Tensor or dict
            Predictions. A dict must hold them under ``"pred"`` (plus ``"class"`` when
            ``separated_class_channel`` is True). A list is treated as several outputs whose losses are
            combined with normalised ``gamma ** i`` weights.
        y_true : torch.Tensor
            Ground truth masks with ``gt_channels_expected`` channels.

        Returns
        -------
        loss : torch.Tensor
            Loss value.
        """
        _y_pred = y_pred["pred"] if isinstance(y_pred, dict) else y_pred

        _y_pred_class = None
        if self.separated_class_channel:
            assert isinstance(y_pred, dict), "When separated_class_channel is True, y_pred should be a dict with 'pred' and 'class' keys. Provided type: {}".format(type(y_pred))
            assert "class" in y_pred, "When separated_class_channel is True, y_pred should be a dict with 'pred' and 'class' keys. Provided keys: {}".format(y_pred.keys())
            _y_pred_class = y_pred["class"]

        assert (y_true.shape[1] == self.gt_channels_expected), (
            "Seems that the GT loaded doesn't have {} channels as expected in {}. GT shape: {}".format(
                self.gt_channels_expected, self.out_channels, y_true.shape
            )
        )

        # Border weights live in the last GT channel: shape (B, *spatial)
        w_borders = y_true[:, -1] if self.extra_weight_in_borders else None

        if not isinstance(_y_pred, list):
            _y_pred = [_y_pred]
            _y_pred_class = [_y_pred_class] if self.separated_class_channel else None
            inter_output_weights = [1.0]
        else:
            w = [self.gamma**i for i in range(len(_y_pred))]
            s = sum(w)
            inter_output_weights = [x / s for x in w]

        loss = 0
        for idx, pred_out in enumerate(_y_pred):
            inter_output_loss = 0
            for i, channel in enumerate(self.out_channels):
                inter_output_loss += self.channel_weights[i] * self._channel_loss(i, channel, pred_out, y_true, w_borders)

            if self.separated_class_channel:
                loss += self.channel_weights[-1] * self._class_channel_loss(_y_pred_class[idx], y_true, w_borders)

            loss += inter_output_weights[idx] * inter_output_loss

        return loss
[docs] def detection_metrics( true_points, pred_points, true_classes=None, pred_classes=None, tolerance=10, resolution: List[int | float] = [1, 1, 1], bbox_to_consider=[], verbose=False, ) -> Tuple[Dict[str, float], pd.DataFrame, pd.DataFrame]: """ Calculate detection metrics (precision, recall, F1) for point-based object detection. Parameters ---------- true_points : List of list List containing coordinates of ground truth points. E.g. ``[[5,3,2], [4,6,7]]``. pred_points : 4D Tensor List containing coordinates of predicted points. E.g. ``[[5,3,2], [4,6,7]]``. true_classes : List of ints, optional Classes of each ground truth points. pred_classes : List of ints, optional Classes of each predicted points. tolerance : optional, int Maximum distance far away from a GT point to consider a point as a true positive. resolution : List of int/float Weights to be multiply by each axis. Useful when dealing with anysotropic data to reduce the distance value on the axis with less resolution. E.g. ``(1,1,0.5)``. bbox_to_consider : List of tuple/list, optional To not take into account during metric calculation to those points outside the bounding box defined with this variable. Order is: ``[[z_min, z_max], [y_min, y_max], [x_min, x_max]]``. For example, using an image of ``10x100x200`` to not take into account points on the first/last slices and with a border of ``15`` pixel for ``x`` and ``y`` axes, this variable could be defined as follows: ``[[1, 9], [15, 85], [15, 185]]``. verbose : bool, optional To print extra information. Returns ------- metrics : List of strings List containing precision, accuracy and F1 between the predicted points and ground truth. """ if len(bbox_to_consider) > 0: assert len(bbox_to_consider) == 3, "'bbox_to_consider' need to be of length 3" assert [len(x) == 2 for x in bbox_to_consider], ( "'bbox_to_consider' needs to be a list of " "two element array/tuple. E.g. 
[[1,1],[15,100],[10,200]]" ) if true_classes is not None and pred_classes is None: raise ValueError("'pred_classes' must be provided when 'true_classes' is set") if true_classes is not None and pred_classes is not None: if len(true_classes) != len(true_points): raise ValueError("'true_points' and 'true_classes' length must be the same") if len(pred_classes) != len(pred_points): raise ValueError("'pred_points' and 'pred_classes' length must be the same") class_metrics = True else: class_metrics = False _true = np.array(true_points, dtype=np.float32) if len(pred_points) > 0: _pred = np.array(pred_points, dtype=np.float32) else: _pred = np.zeros((0, 3), dtype=np.float32) TP, FP, FN = 0, 0, 0 tag = ["FN" for x in _true] fp_preds = list(range(1, len(_pred) + 1)) dis = [-1 for x in _true] pred_id_assoc = [-1 for x in _true] TP_not_considered = 0 if len(_true) > 0: # Multiply each axis for the its real value for i in range(len(resolution)): _true[:, i] *= resolution[i] _pred[:, i] *= resolution[i] # Create cost matrix distances = distance_matrix(_pred, _true) n_matched = min(len(_true), len(_pred)) costs = -(distances >= tolerance).astype(float) - distances / (2 * n_matched) pred_ind, true_ind = linear_sum_assignment(-costs) # Analyse which associations are below the tolerance to consider them TP for i in range(len(pred_ind)): # Filter out those point outside the defined bounding box consider_point = False if len(bbox_to_consider) > 0: point = true_points[true_ind[i]] if ( bbox_to_consider[0][0] <= point[0] <= bbox_to_consider[0][1] and bbox_to_consider[1][0] <= point[1] <= bbox_to_consider[1][1] and bbox_to_consider[2][0] <= point[2] <= bbox_to_consider[2][1] ): consider_point = True else: consider_point = True if distances[pred_ind[i], true_ind[i]] < tolerance: if consider_point: TP += 1 tag[true_ind[i]] = "TP" else: tag[true_ind[i]] = "NC" TP_not_considered += 1 fp_preds.remove(pred_ind[i] + 1) dis[true_ind[i]] = distances[pred_ind[i], true_ind[i]] 
pred_id_assoc[true_ind[i]] = pred_ind[i] + 1 if TP_not_considered > 0: print(f"{TP_not_considered} TPs not considered due to filtering") FN = len(_true) - TP - TP_not_considered # FP filtering FP_not_considered = 0 fp_tags = ["FP" for x in fp_preds] if len(bbox_to_consider) > 0: for i in range(len(fp_preds)): point = pred_points[fp_preds[i] - 1] if not ( bbox_to_consider[0][0] <= point[0] <= bbox_to_consider[0][1] and bbox_to_consider[1][0] <= point[1] <= bbox_to_consider[1][1] and bbox_to_consider[2][0] <= point[2] <= bbox_to_consider[2][1] ): FP_not_considered += 1 fp_tags[i] = "NC" print(f"{FP_not_considered} FPs not considered due to filtering") FP = len(fp_preds) - FP_not_considered # Create two dataframes with the GT and prediction points association made and another one with the FPs df, df_fp = None, None df_columns =[ "gt_id", "pred_id", "distance", "tag", "axis-0", "axis-1", "axis-2", "gt_class", "pred_axis-0", "pred_axis-1", "pred_axis-2", "pred_class", ] if len(_true) > 0: _true = np.array(true_points, dtype=np.float32) if len(pred_points) > 0: _pred = np.array(pred_points, dtype=np.float32) else: _pred = np.zeros((0, 3), dtype=np.float32) # Capture FP coords fp_coords = np.zeros((len(fp_preds), _pred.shape[-1])) pred_fp_class = [-1] * len(fp_preds) for i in range(len(fp_preds)): fp_coords[i] = _pred[fp_preds[i] - 1] if class_metrics: assert pred_classes is not None pred_fp_class[i] = int(pred_classes[fp_preds[i] - 1]) # Capture prediction coords pred_coords = np.zeros((len(pred_id_assoc), _pred.shape[-1]), dtype=np.float32) pred_class = [-1] * len(pred_id_assoc) if not class_metrics: true_classes = [-1] * len(pred_id_assoc) for i in range(len(pred_id_assoc)): if pred_id_assoc[i] != -1: pred_coords[i] = _pred[pred_id_assoc[i] - 1] if class_metrics: assert pred_classes is not None pred_class[i] = int(pred_classes[pred_id_assoc[i] - 1]) else: pred_coords[i] = [0] * _pred.shape[-1] df = pd.DataFrame( zip( list(range(1, len(_true) + 1)), pred_id_assoc, dis, 
tag, _true[..., 0], _true[..., 1], _true[..., 2], true_classes, # type: ignore pred_coords[..., 0], pred_coords[..., 1], pred_coords[..., 2], pred_class, ), # type: ignore columns=df_columns, ) df_fp = pd.DataFrame( zip( fp_preds, fp_coords[..., 0], fp_coords[..., 1], fp_coords[..., 2], fp_tags, pred_fp_class, ), columns=["pred_id", "axis-0", "axis-1", "axis-2", "tag", "pred_class"], ) try: precision = TP / (TP + FP) except: precision = 0 try: recall = TP / (TP + FN) except: recall = 0 try: F1 = 2 * ((precision * recall) / (precision + recall)) except: F1 = 0 if not class_metrics: if df is not None: df = df.drop(columns=["gt_class", "pred_class"]) if df_fp is not None: df_fp = df_fp.drop(columns=["pred_class"]) else: if df is not None: # Class metrics must be computed only on true detections (TP by distance), # not on all Hungarian associations. tp_df = df[df["tag"] == "TP"] if len(tp_df) > 0: TP_classes = int((tp_df["gt_class"] == tp_df["pred_class"]).sum()) FN_classes = int((tp_df["gt_class"] != tp_df["pred_class"]).sum()) else: TP_classes = 0 FN_classes = 0 else: TP_classes = 0 FN_classes = 0 try: precision_classes = TP_classes / (TP_classes + FN_classes) except: precision_classes = 0 try: recall_classes = TP_classes / (TP_classes + FN_classes) except: recall_classes = 0 try: F1_classes = 2 * ((precision_classes * recall_classes) / (precision_classes + recall_classes)) except: F1_classes = 0 if verbose: if len(bbox_to_consider) > 0: print( "Points in ground truth: {} ({} total but {} not considered), Points in prediction: {} " "({} total but {} not considered)".format( len(_true), len(true_points), TP_not_considered, len(_pred), len(pred_points), FP_not_considered, ) ) else: print("Points in ground truth: {}, Points in prediction: {}".format(len(_true), len(_pred))) print("True positives: {}, False positives: {}, False negatives: {}".format(int(TP), int(FP), int(FN))) if class_metrics: print("True positives (class): {}, False negatives (class): 
{}".format(int(TP_classes), int(FN_classes))) if not class_metrics: r_dict = { "Precision": precision, "Recall": recall, "F1": F1, "TP": int(TP), "FP": int(FP), "FN": int(FN), } else: r_dict = { "Precision": precision, "Recall": recall, "F1": F1, "TP": int(TP), "FP": int(FP), "FN": int(FN), "Precision (class)": precision_classes, "Recall (class)": recall_classes, "F1 (class)": F1_classes, "TP (class)": int(TP_classes), "FN (class)": int(FN_classes), } if df is None: if "gt_class" in df_columns: df_columns.remove("gt_class") if "pred_class" in df_columns: df_columns.remove("pred_class") df = pd.DataFrame(columns=df_columns) if df_fp is None: df_fp = pd.DataFrame(columns=["pred_id", "axis-0", "axis-1", "axis-2", "tag"]) return r_dict, df, df_fp
class SSIM_loss(torch.nn.Module):
    """
    Structural-similarity loss defined as ``1 - SSIM(prediction, target)``.

    Backed by torchmetrics' ``StructuralSimilarityIndexMeasure``; identical images give a loss of 0.
    """

    def __init__(self, data_range, device):
        """
        Build the loss and move its SSIM metric to ``device``.

        Parameters
        ----------
        data_range : float
            Value range of the input images (e.g. 1.0 or 255).
        device : torch.device
            Device on which the SSIM metric is placed.
        """
        super().__init__()
        metric = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.ssim = metric.to(device, non_blocking=True)

    def forward(self, input, target):
        """
        Return ``1 - SSIM(input, target)``.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images, or a dict holding them under ``"pred"``.
        target : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            One minus the SSIM value, so lower is better.
        """
        prediction = input["pred"] if isinstance(input, dict) else input
        similarity = self.ssim(prediction, target)
        return 1 - similarity
class W_MAE_SSIM_loss(torch.nn.Module):
    """
    Weighted sum of a Mean Absolute Error term and an SSIM term.

    The loss is ``w_mae * MAE + w_ssim * (1 - SSIM)`` and is meant for image regression, where it trades
    pixel-wise accuracy against perceptual (structural) similarity.
    """

    def __init__(self, data_range, device, w_mae=0.5, w_ssim=0.5):
        """
        Build the MAE and SSIM components and store their weights.

        Parameters
        ----------
        data_range : float
            Value range of the input images (e.g. 1.0 or 255).
        device : torch.device
            Device on which both components are placed.
        w_mae : float, optional
            Weight of the MAE term (default: 0.5).
        w_ssim : float, optional
            Weight of the SSIM term (default: 0.5).
        """
        super().__init__()
        # The attribute is called `mse` for historical reasons but holds an L1 (MAE) criterion.
        self.mse = torch.nn.L1Loss().to(device, non_blocking=True)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=data_range).to(device, non_blocking=True)
        self.w_mae = w_mae
        self.w_ssim = w_ssim

    def forward(self, input, target):
        """
        Return ``w_mae * MAE(input, target) + w_ssim * (1 - SSIM(input, target))``.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images, or a dict holding them under ``"pred"``.
        target : torch.Tensor
            Ground truth images.

        Returns
        -------
        loss : torch.Tensor
            Weighted sum of MAE and (1 - SSIM).
        """
        prediction = input["pred"] if isinstance(input, dict) else input
        pixel_term = self.mse(prediction, target) * self.w_mae
        structure_term = (1 - self.ssim(prediction, target)) * self.w_ssim
        return pixel_term + structure_term
class W_MSE_SSIM_loss(torch.nn.Module):
    """
    Weighted sum of a mean-squared-error term and an SSIM term.

    The loss is ``w_mse * MSE(pred, target) + w_ssim * (1 - SSIM(pred, target))``.
    It trades pixel-wise fidelity against structural (perceptual) similarity in
    image-regression tasks.
    """

    def __init__(self, data_range, device, w_mse=0.5, w_ssim=0.5):
        """
        Build the MSE and SSIM components and move them to ``device``.

        Parameters
        ----------
        data_range : float
            Value range of the input images (e.g. 1.0 or 255).
        device : torch.device
            Device on which the loss components are placed.
        w_mse : float, optional
            Weight of the MSE term (default: 0.5).
        w_ssim : float, optional
            Weight of the ``1 - SSIM`` term (default: 0.5).
        """
        super().__init__()
        self.w_mse = w_mse
        self.w_ssim = w_ssim
        pixel_loss = torch.nn.MSELoss()
        structural_metric = StructuralSimilarityIndexMeasure(data_range=data_range)
        self.mse = pixel_loss.to(device, non_blocking=True)
        self.ssim = structural_metric.to(device, non_blocking=True)

    def forward(self, input, target):
        """
        Evaluate the weighted MSE + (1 - SSIM) loss.

        Parameters
        ----------
        input : torch.Tensor or dict
            Predicted images; when a dict is given, the tensor stored under ``"pred"`` is used.
        target : torch.Tensor
            Ground-truth images.

        Returns
        -------
        loss : torch.Tensor
            ``w_mse * MSE + w_ssim * (1 - SSIM)``.
        """
        pred = input["pred"] if isinstance(input, dict) else input
        pixel_term = self.mse(pred, target) * self.w_mse
        structural_term = (1 - self.ssim(pred, target)) * self.w_ssim
        return pixel_term + structural_term
def n2v_loss_mse(y_pred, y_true):
    """
    Noise2Void MSE loss for self-supervised denoising.

    ``y_true`` stacks the target channels followed by the mask channels along axis 1.
    The loss is ``sum((target - y_pred * mask) ** 2) / sum(mask)``. When the mask is
    empty, a zero loss (with zero gradients) is returned instead of the NaN/inf that
    the raw division would produce.

    Parameters
    ----------
    y_pred : torch.Tensor or dict
        Predicted output. If a dict, the prediction is read from the ``"pred"`` key.
    y_true : torch.Tensor
        Ground truth and mask: the first ``y_pred.shape[1]`` channels are the target,
        the remaining channels the mask.

    Returns
    -------
    loss : torch.Tensor
        Loss value.
    """
    if isinstance(y_pred, dict):
        y_pred = y_pred["pred"]
    n_chan = y_pred.shape[1]
    target = y_true[:, :n_chan]
    mask = y_true[:, n_chan:]

    squared_error = torch.sum(torch.square(target - y_pred * mask))
    n_masked = torch.sum(mask)
    # Guard against an empty mask without a host sync: divide by 1 in that case and
    # then select 0, so neither the value nor the gradient becomes NaN/inf.
    has_mask = n_masked > 0
    safe_denom = torch.where(has_mask, n_masked, torch.ones_like(n_masked))
    loss = torch.where(has_mask, squared_error / safe_denom, torch.zeros_like(squared_error))
    return loss
class SSIM_wrapper:
    """Callable ``1 - SSIM`` loss backed by ``pytorch_msssim.SSIM``."""

    def __init__(self, data_range=1, channel=1, **ssim_kwargs):
        """
        Create the wrapped ``pytorch_msssim.SSIM`` module.

        Parameters
        ----------
        data_range : float, optional
            Value range of the input images (default: 1, as before).
        channel : int, optional
            Number of image channels the SSIM window is built for (default: 1, as before).
        **ssim_kwargs
            Extra keyword arguments forwarded to ``pytorch_msssim.SSIM``
            (e.g. ``spatial_dims=3`` for volumes, ``win_size``). ``size_average``
            defaults to ``True``.
        """
        ssim_kwargs.setdefault("size_average", True)
        self.loss = SSIM(data_range=data_range, channel=channel, **ssim_kwargs)

    def __call__(self, y_pred, y_true):
        """
        Calculate the SSIM loss.

        Parameters
        ----------
        y_pred : torch.Tensor or dict
            Predictions. If a dict, the prediction is read from the ``"pred"`` key.
        y_true : torch.Tensor
            Ground-truth images.

        Returns
        -------
        loss : torch.Tensor
            ``1 - SSIM(y_pred, y_true)``.
        """
        if isinstance(y_pred, dict):
            y_pred = y_pred["pred"]
        return 1 - self.loss(y_pred, y_true)
[docs] def lovasz_hinge(logits: torch.Tensor, labels: torch.Tensor, per_image: bool = True, ignore_index: int | None = None) -> torch.Tensor: """ Single-function binary Lovász hinge loss. - logits: unnormalized scores, same shape as labels (e.g. (1,H,W) or (1,D,H,W)) - labels: {0,1} or bool tensor, same shape as logits - per_image: average loss per-item if a batch dim exists - ignore_index: label value to ignore (optional) """ if logits.shape != labels.shape: raise ValueError(f"Shape mismatch: logits {logits.shape} vs labels {labels.shape}") # Handle per-image averaging if a batch dim is present if per_image and logits.dim() >= 2 and logits.size(0) > 1: losses = [] for li, yi in zip(logits, labels): # Flatten and optionally filter ignore_index l_flat = li.reshape(-1) y_flat = yi.to(dtype=torch.long, device=li.device).reshape(-1) if ignore_index is not None: valid = (y_flat != ignore_index) l_flat = l_flat[valid] y_flat = y_flat[valid] if l_flat.numel() == 0: continue # Signs in {-1,+1}, hinge errors, sort desc signs = y_flat.float() * 2 - 1 errors = 1 - l_flat * signs errs_sorted, perm = torch.sort(errors, descending=True) y_sorted = y_flat[perm].float() # Lovász gradient (Jaccard) in-place, no helpers p = y_sorted.numel() if p == 0: continue gts = y_sorted.sum() inter = gts - y_sorted.cumsum(0) union = gts + (1 - y_sorted).cumsum(0) jacc = 1.0 - inter / torch.clamp_min(union, 1.0) if p > 1: jacc[1:p] = jacc[1:p] - jacc[0:p-1] losses.append(F.relu(errs_sorted) @ jacc) return (torch.stack(losses).mean() if len(losses) else logits.new_tensor(0.0)) # Single item (or per_image=False): same steps without the loop l_flat = logits.reshape(-1) y_flat = labels.to(dtype=torch.long, device=logits.device).reshape(-1) if ignore_index is not None: valid = (y_flat != ignore_index) l_flat = l_flat[valid] y_flat = y_flat[valid] if l_flat.numel() == 0: return logits.new_tensor(0.0) signs = y_flat.float() * 2 - 1 errors = 1 - l_flat * signs errs_sorted, perm = torch.sort(errors, 
descending=True) y_sorted = y_flat[perm].float() p = y_sorted.numel() gts = y_sorted.sum() inter = gts - y_sorted.cumsum(0) union = gts + (1 - y_sorted).cumsum(0) jacc = 1.0 - inter / torch.clamp_min(union, 1.0) if p > 1: jacc[1:p] = jacc[1:p] - jacc[0:p-1] return F.relu(errs_sorted) @ jacc
[docs] class SpatialEmbLoss(nn.Module): """ Spatial Embedding Loss for 2D and 3D inspired by `EmbedSeg <https://github.com/juglab/EmbedSeg/tree/main>`__. Parameters ---------- patch_size : List of int, optional Patch size used during training (used to build coordinate map buffer). anisotropy : List of float or int, optional Anisotropy factors for each axis (z,y,x). ndims : int, optional Number of spatial dimensions (2 or 3). center_mode : str, optional Method to compute object center: "centroid" or "medoid". medoid_max_points : int, optional Maximum number of points to use when computing medoid (to avoid O(N^2) complexity). channel_weights : List of float, optional Weights for the different loss components: [foreground, instance, variance, seed]. """ def __init__( self, patch_size: List[int] = [32, 1024, 1024], anisotropy: List[float | int] = [1,1,1], ndims: int = 2, center_mode: str = "centroid", # "centroid" or "medoid" medoid_max_points: Optional[int] = 10000, # cap to avoid O(N^2) on huge objects channel_weights: List[float] = [1.0, 1.0, 1.0], ): super().__init__() self.ndims = ndims self.center_mode = center_mode self.medoid_max_points = medoid_max_points # Grid sizes (used to build the coordinate map buffer; sliced to input size on forward) grid_z = patch_size[0] if ndims == 3 else 1 grid_y = patch_size[-3] grid_x = patch_size[-2] # Pixel sizes (coordinate extents) pixel_z = anisotropy[0] if ndims == 3 else 1 pixel_y = anisotropy[1] pixel_x = anisotropy[2] self.channel_weights = channel_weights self.foreground_weight = self.channel_weights[0] self.w_inst = self.channel_weights[1] self.w_var = self.channel_weights[2] self.w_seed = self.channel_weights[3] # Build max-size 3D coordinate grid buffer; for 2D we will slice z=1. # This lets one class handle both 2D (uses x,y) and 3D (uses x,y,z). 
xm = ( torch.linspace(0, pixel_x, grid_x) .view(1, 1, 1, -1) .expand(1, grid_z, grid_y, grid_x) ) ym = ( torch.linspace(0, pixel_y, grid_y) .view(1, 1, -1, 1) .expand(1, grid_z, grid_y, grid_x) ) zm = ( torch.linspace(0, pixel_z, grid_z) .view(1, -1, 1, 1) .expand(1, grid_z, grid_y, grid_x) ) # Stack as (3, Z, Y, X); for 2D we’ll slice to (2, Y, X) at forward time. xyzm = torch.cat((xm, ym, zm), 0) # (3, Z, Y, X) self.register_buffer("xyzm", xyzm) def _calculate_binary_iou(self, pred, label): intersection = ((label == 1) & (pred == 1)).sum() union = ((label == 1) | (pred == 1)).sum() if not union: return 0 else: iou = intersection.item() / union.item() return iou @torch.no_grad() def _center_from_mask( self, coords: torch.Tensor, # (D, ...) in_mask: torch.Tensor, # (1, ...) ) -> torch.Tensor: """ Compute object center from binary mask using centroid or medoid. Parameters ---------- coords : torch.Tensor Coordinate grid tensor of shape (D, ...), where D is the number of spatial dimensions in_mask : torch.Tensor Binary mask tensor of shape (1, ...), indicating the object pixels/voxels. Returns ------- center : torch.Tensor Computed center coordinates of shape (D, 1, ..., 1). 
""" D = coords.size(0) # Extract coordinates of all pixels/voxels in the instance: (N, D) pts = coords[in_mask.expand_as(coords)].view(D, -1).t().contiguous() # (N, D) if pts.numel() == 0: # No pixels: fall back to zeros return torch.zeros(D, *([1] * (coords.dim() - 1)), device=coords.device, dtype=coords.dtype) if self.center_mode == "centroid" or pts.shape[0] == 1: c = pts.mean(0) # (D,) else: # MEDOID: minimize sum of Euclidean distances to all other points # Optionally sub-sample to keep cdist tractable if self.medoid_max_points is not None and pts.shape[0] > self.medoid_max_points: idx = torch.randperm(pts.shape[0], device=pts.device)[: self.medoid_max_points] pts_sub = pts[idx] dist = torch.cdist(pts_sub, pts_sub, p=2) # (M, M) sums = dist.sum(dim=1) best = torch.argmin(sums) c = pts_sub[best] # approximate medoid else: dist = torch.cdist(pts, pts, p=2) # (N, N) sums = dist.sum(dim=1) best = torch.argmin(sums) c = pts[best] # exact medoid return c.view(D, *([1] * (coords.dim() - 1))) # (D, 1, ..., 1)
[docs] def forward( self, prediction: torch.Tensor, # (B, C, H, W) or (B, C, D, H, W) instances: torch.Tensor, # (B, H, W) or (B, D, H, W) ) -> Tuple[torch.Tensor, float, str]: if prediction.dim() not in (4, 5): raise ValueError( f"Unsupported prediction tensor dimensionality {prediction.dim()}. " "Expected 4D (B,C,H,W) or 5D (B,C,D,H,W)." ) B = prediction.size(0) D = prediction.dim() - 2 # number of spatial dims (2 or 3) assert D in (2, 3), "Only 2D or 3D supported" assert D == self.ndims, f"Model ndims={D} does not match loss ndims={self.ndims}" # Spatial sizes if D == 2: H, W = prediction.size(2), prediction.size(3) # coords: (2, H, W) from self.xyzm (3, Z, Y, X) coords = self.xyzm[:2, 0, :H, :W].contiguous() total_voxels = H * W else: Z, H, W = prediction.size(2), prediction.size(3), prediction.size(4) # coords: (3, Z, H, W) coords = self.xyzm[:3, :Z, :H, :W].contiguous() total_voxels = Z * H * W # Remove the extra channel dimension in instances instances = instances[:, 0] # Channel partition emb_ch = D # 2 for 2D, 3 for 3D sig_ch = self.ndims # equal to D seed_ch = 1 loss = prediction.new_tensor(0.0) for b in range(B): # Spatial embedding (tanh) + coordinate grid spatial_emb = torch.tanh(prediction[b, :emb_ch]) + coords # (D, ...) sigma = prediction[b, emb_ch:emb_ch + sig_ch] # (D, ...) seed_map = torch.sigmoid( prediction[b, emb_ch + sig_ch: emb_ch + sig_ch + seed_ch] ) # (1, ...) var_loss = prediction.new_tensor(0.0) instance_loss = prediction.new_tensor(0.0) seed_loss = prediction.new_tensor(0.0) iou = prediction.new_tensor(0.0) obj_count = 0 instance = instances[b].unsqueeze(0) # (1, ...) instance_ids = instance.unique() instance_ids = instance_ids[instance_ids != 0] # Regress background seeds to zero bg_mask = (instances[b] == 0).unsqueeze(0) # (1, ...) if bg_mask.sum() > 0: seed_loss = seed_loss + torch.sum(torch.pow(seed_map[bg_mask] - 0, 2)) for idv in instance_ids: in_mask = instance.eq(idv) # (1, ...) 
center = self._center_from_mask(coords, in_mask) # Sigma stats on object pixels/voxels sigma_in = sigma[in_mask.expand_as(sigma)].view(sig_ch, -1) # (D, N) s_mean = sigma_in.mean(1).view(sig_ch, 1) # (D, 1) # Variance loss (before exp), detaching mean to match originals var_loss = var_loss + torch.mean(torch.pow(sigma_in - s_mean.detach(), 2)) # Distance field s = torch.exp(s_mean.view(sig_ch, *([1] * D)) * 10) # (D, 1...1) dist = torch.exp( -1 * torch.sum(torch.pow(spatial_emb - center, 2) * s, dim=0, keepdim=True) ) # (1, ...) # Instance (Lovász hinge) loss on the soft mask instance_loss = instance_loss + lovasz_hinge(dist * 2 - 1, in_mask) # Seed regression loss towards distance field (fg only) seed_loss = seed_loss + self.foreground_weight * torch.sum( torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2) ) # Measure IoU at 0.5 threshold iou += self._calculate_binary_iou(dist > 0.5, in_mask) obj_count += 1 if obj_count > 0: instance_loss = instance_loss / obj_count var_loss = var_loss / obj_count iou = iou / obj_count seed_loss = seed_loss / total_voxels loss = loss + (self.w_inst * instance_loss + self.w_var * var_loss + self.w_seed * seed_loss) loss = loss / B iou = iou / B return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals