Source code for biom3d.metrics

"""
Metrics/losses.

Mostly for segmentation.
"""

import torch 
from torch import Type, nn 
import torch.nn.functional as F
import numpy as np

from abc import abstractmethod
from typing import Any, Callable, Literal, Optional

#---------------------------------------------------------------------------
# Metrics base class
class Metric(nn.Module):
    """
    Abstract base class for all metrics.

    Stores and updates metric statistics: the current value, the running
    average, the cumulative sum and the sample count.

    In the biom3d structure, all metrics must provide:

    - a `name` attribute for identification,
    - a `reset()` method to clear internal states,
    - a `forward()` method to compute the metric given predictions and targets,
    - an `update()` method to accumulate statistics.

    To define a new metric:

    1. Inherit from `Metric`.
    2. Set the `name` attribute in `__init__`.
    3. Implement `forward()` to compute the metric and assign it to `self.val`.

    :ivar str name: Name of the metric (used in logging/display).
    :ivar float val: Current value of the metric for the last batch.
    :ivar float avg: Running average of the metric.
    :ivar float sum: Cumulative sum of the metric across all updates.
    :ivar int count: Number of samples accumulated (used to compute the average).
    """

    def __init__(self, name: Optional[str] = None):
        """
        Initialize the base Metric class.

        Parameters
        ----------
        name : str, optional
            Name of the metric, for display/logging purposes.
        """
        super(Metric, self).__init__()
        self.name = name
        self.reset()

    def reset(self) -> None:
        """
        Reset internal metric statistics (`val`, `avg`, `sum`, `count`) to zero.

        Returns
        -------
        None
        """
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    @abstractmethod
    def forward(self, preds: torch.Tensor, trues: torch.Tensor) -> None:
        """
        Compute the metric value for the given predictions and targets.

        Must be implemented by subclasses, which are expected to store the
        result in `self.val`.

        Parameters
        ----------
        preds : torch.Tensor
            Model predictions.
        trues : torch.Tensor
            Ground truth targets.

        Returns
        -------
        None
        """
        pass

    def update(self, n: int = 1) -> None:
        """
        Accumulate the current value `self.val` into the running statistics.

        Parameters
        ----------
        n : int, default=1
            Number of samples in the current batch (weight of `self.val`).

        Returns
        -------
        None
        """
        # no_grad: `self.val` may be a tensor attached to a graph; the running
        # statistics must not extend that graph.
        with torch.no_grad():
            self.sum += self.val * n
            self.count += n
            # Guard against an empty history (e.g. update(0) right after
            # reset()), which would otherwise divide by zero.
            if self.count != 0:
                self.avg = self.sum / self.count

    def str(self) -> str:
        """
        Return the string representation of the current value.

        Returns
        -------
        str
            String representation of the current value `self.val`.
        """
        return str(self.val)

    def __str__(self):
        """
        Return a formatted string of the average or, if it is zero, the current value.

        Returns
        -------
        str
            Formatted metric string: "<name> <value>" with three decimals.
        """
        # Fall back to the last value while no average has been accumulated yet.
        to_print = self.avg if self.avg != 0 else self.val
        return "{} {:.3f}".format(self.name, to_print)
#--------------------------------------------------------------------------- # Metrics/losses for semantic segmentation
class Dice(Metric):
    """
    Dice score metric.

    Computes the Dice coefficient between predictions and targets, either
    with a sigmoid activation (default) or with a softmax over axis 1. In the
    softmax case the first channel (index 0, the background) is dropped
    before the score is computed.

    :ivar str name: Name of the metric (used in logs).
    :ivar bool use_softmax: Whether to apply softmax before Dice computation.
    :ivar tuple[int] dim: Dimensions over which Dice is computed (e.g., (2, 3) or (2, 3, 4)).
    """

    def __init__(self,
                 use_softmax: bool = False,
                 dim: tuple[int] = (),
                 name: str = None):
        """
        Initialize the Dice metric.

        Parameters
        ----------
        use_softmax : bool, default=False
            Whether to apply softmax to model outputs. If True, the background
            channel is removed.
        dim : tuple, default=()
            Dimensions along which Dice is computed. Use (2,3) for 2D or
            (2,3,4) for 3D images.
        name : str, optional
            Name of the metric for logging and display.
        """
        super(Dice, self).__init__()
        self.use_softmax = use_softmax  # softmax mode also drops the background channel
        self.dim = dim
        self.name = name

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor,
                smooth: float = 1.0,
                ) -> torch.Tensor:
        """
        Compute the Dice coefficient between predictions and ground truth.

        Parameters
        ----------
        inputs : torch.Tensor
            Model outputs (logits); the activation is applied here.
        targets : torch.Tensor
            Ground truth segmentation masks. In softmax mode, targets whose
            shape differs from `inputs` are scattered into a one-hot tensor
            along axis 1 (presumably class indices with a singleton channel
            axis -- confirm against callers).
        smooth : float, default=1.0
            Smoothing constant added to numerator and denominator.

        Returns
        -------
        self.val: torch.Tensor
            Mean Dice score, or `1 - mean Dice` when the module is in
            training mode.
        """
        if not self.use_softmax:
            probs = inputs.sigmoid()
            masks = targets
        else:
            probs = inputs.softmax(dim=1)
            shapes_match = all(p == t for p, t in zip(probs.shape, targets.shape))
            if shapes_match:
                # Same shape as the prediction: treated as already one-hot.
                onehot = targets
            else:
                onehot = torch.zeros(probs.shape, device=probs.device)
                onehot.scatter_(1, targets.long(), 1)
            # Drop the background channel from both tensors.
            probs = probs[:, 1:]
            masks = onehot[:, 1:]

        overlap = (probs * masks).sum(dim=self.dim)
        denominator = probs.sum(dim=self.dim) + masks.sum(dim=self.dim) + smooth
        score = ((2. * overlap + smooth) / denominator).mean()

        # In training mode the value is turned into a loss (1 - Dice).
        if self.training:
            self.val = 1 - score
        else:
            self.val = score
        return self.val
class DiceBCE(Metric):
    """
    Dice + (binary) cross-entropy metric.

    This combined metric can also be used as a loss function. It returns the
    sum of a cross-entropy term and a Dice loss term. With `use_softmax` the
    cross-entropy is a multi-class `CrossEntropyLoss` and the first channel
    (background) is removed from the Dice computation; otherwise a binary
    cross-entropy with logits and a sigmoid Dice are used.

    :ivar str name: Name of the metric (used in logs).
    :ivar bool use_softmax: Whether to apply softmax and remove background for Dice computation.
    :ivar tuple[int] dim: Dimensions over which Dice is computed.
    :ivar torch.nn.CrossEntropyLoss bce: Cross-entropy module used in softmax mode.
    """

    def __init__(self,
                 use_softmax: bool = False,
                 dim: tuple[int] = (),
                 name: Optional[str] = None):
        """
        Initialize the DiceBCE metric.

        Parameters
        ----------
        use_softmax : bool, default=False
            Whether to apply softmax to model outputs. If True, background is
            excluded from Dice.
        dim : tuple, default=()
            Dimensions along which Dice is computed (e.g., (2, 3) or (2, 3, 4)).
        name : str, optional
            Name of the metric for logging and display.
        """
        super(DiceBCE, self).__init__()
        self.use_softmax = use_softmax  # if use softmax then remove bg for dice computation
        self.name = name
        self.dim = dim  # axis defined for the dice score
        self.bce = torch.nn.CrossEntropyLoss(reduction='mean')

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor,
                smooth: float = 1.0,
                ) -> torch.Tensor:
        """
        Compute the combined Dice and cross-entropy loss.

        Parameters
        ----------
        inputs : torch.Tensor
            Model outputs (logits).
        targets : torch.Tensor
            Ground truth segmentation masks. In softmax mode they may be
            one-hot (same shape as `inputs`) or class indices stored in a
            single channel on axis 1.
        smooth : float, default=1.0
            Smoothing constant to avoid division by zero in Dice computation.

        Returns
        -------
        self.val: torch.Tensor
            The cross-entropy term plus the mean Dice loss.
        """
        if self.use_softmax:
            if all([i == j for i, j in zip(inputs.shape, targets.shape)]):
                # Same shape as the logits: treated as a one-hot encoding,
                # so the class index is recovered with an argmax.
                targets_oh = targets
                class_idx = targets.argmax(dim=1).long()
            else:
                # Class-index targets: use the index channel directly. Taking
                # an argmax over a single channel would yield 0 everywhere and
                # silently train the cross-entropy against background only.
                class_idx = targets[:, 0].long()
                targets_oh = torch.zeros(inputs.shape, device=inputs.device)
                targets_oh.scatter_(1, targets.long(), 1)

            BCE = self.bce(inputs, class_idx)

            # for dice computation, remove the background channel
            inputs = inputs.softmax(dim=1)[:, 1:]
            targets = targets_oh[:, 1:]
        else:
            # binary case: keep every channel
            BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
            inputs = inputs.sigmoid()

        intersection = (inputs * targets).sum(dim=self.dim)
        dice_loss = 1 - (2. * intersection + smooth) / (
            inputs.sum(dim=self.dim) + targets.sum(dim=self.dim) + smooth)

        self.val = BCE + dice_loss.mean()
        return self.val
class IoU(Metric):
    """
    Intersection over Union (IoU) score.

    Measures the overlap between predicted and ground truth segmentation
    masks. In training mode the value is inverted (`1 - IoU`) so it can be
    used as a loss. Uses a sigmoid activation by default, or a softmax over
    axis 1 with the first channel (background) removed.

    :ivar str name: Name of the metric.
    :ivar bool use_softmax: Whether to apply softmax (multi-class case).
    :ivar tuple dim: Dimensions along which the IoU is computed.
    """

    def __init__(self,
                 use_softmax: bool = False,
                 dim: tuple = (),
                 name: Optional[str] = None):
        """
        Initialize the IoU metric.

        Parameters
        ----------
        use_softmax : bool, default=False
            Whether to apply softmax to model outputs. If True, background is
            excluded from IoU computation.
        dim : tuple, default=()
            Dimensions over which IoU is computed (e.g., (2, 3, 4) for 3D or
            (2, 3) for 2D).
        name : str, optional
            Name of the metric (for logging or display).
        """
        super(IoU, self).__init__()
        self.name = name
        self.dim = dim
        self.use_softmax = use_softmax  # softmax mode also drops the background channel

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor,
                smooth: float = 1.0,
                ) -> torch.Tensor:
        """
        Compute the IoU between predicted and target masks.

        Parameters
        ----------
        inputs : torch.Tensor
            Logits output by the model.
        targets : torch.Tensor
            Ground truth segmentation masks. In softmax mode, targets whose
            shape differs from `inputs` are scattered into a one-hot tensor
            along axis 1.
        smooth : float, default=1.0
            Smoothing term added to numerator and denominator.

        Returns
        -------
        self.val: torch.Tensor
            Mean IoU score, or `1 - mean IoU` when the module is in training
            mode.
        """
        if not self.use_softmax:
            probs = inputs.sigmoid()
            masks = targets
        else:
            probs = inputs.softmax(dim=1)
            shapes_match = all(p == t for p, t in zip(probs.shape, targets.shape))
            if shapes_match:
                # Same shape as the prediction: treated as already one-hot.
                onehot = targets
            else:
                onehot = torch.zeros(probs.shape, device=probs.device)
                onehot.scatter_(1, targets.long(), 1)
            # Drop the background channel from both tensors.
            probs = probs[:, 1:]
            masks = onehot[:, 1:]

        # Soft true-positive mass, and the inclusive area of prediction + label.
        overlap = (probs * masks).sum(dim=self.dim)
        combined = (probs + masks).sum(dim=self.dim)
        union = combined - overlap

        score = ((overlap + smooth) / (union + smooth)).mean()

        if self.training:
            self.val = 1 - score
        else:
            self.val = score
        return self.val
class MSE(Metric):
    """
    Mean Squared Error (MSE) loss.

    Averages the squared difference between predictions and targets and
    stores it as the metric's current value.

    :ivar str name: Name of the metric (for logging or display).
    """

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean squared error between predictions and targets.

        Parameters
        ----------
        inputs : torch.Tensor
            Predicted values.
        targets : torch.Tensor
            Ground truth values.

        Returns
        -------
        torch.Tensor
            Computed MSE loss (also stored in `self.val`).
        """
        mse = F.mse_loss(inputs, targets, reduction='mean')
        self.val = mse
        return mse
class CrossEntropy(Metric):
    """
    Cross-entropy loss metric.

    Thin wrapper that evaluates a mean-reduced `torch.nn.CrossEntropyLoss`
    and stores the result as the metric's current value.

    :ivar str name: Name of the metric (used in logging or display).
    :ivar torch.nn.CrossEntropyLoss ce: Internal cross-entropy loss module.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Initialize the CrossEntropy metric.

        Parameters
        ----------
        name : str, optional
            Name of the metric, for display or logging purposes.
        """
        super().__init__()
        self.ce = nn.CrossEntropyLoss(reduction='mean')
        self.name = name

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the cross-entropy loss between predictions and targets.

        Parameters
        ----------
        inputs : torch.Tensor
            Raw class scores (logits), passed unchanged to the loss module.
        targets : torch.Tensor
            Ground truth in a form accepted by `torch.nn.CrossEntropyLoss`.

        Returns
        -------
        torch.Tensor
            Computed cross-entropy loss (also stored in `self.val`).
        """
        loss = self.ce(inputs, targets)
        self.val = loss
        return loss
#--------------------------------------------------------------------------- # Metric adaptation for deep supervision
class DeepMetric(Metric):
    """
    Metric wrapper for deep supervision.

    Applies one base metric to several network outputs (e.g. predictions
    taken at different decoder depths) and combines the results as a weighted
    sum using the coefficients in `alphas`. Every output is compared to the
    same ground truth.

    :ivar Metric metric: Base metric applied at each level.
    :ivar list[float] alphas: Weights associated with each level's output.
    :ivar str name: Name of the metric.
    """

    def __init__(self,
                 metric: type[Metric],
                 alphas: list[float],
                 name: Optional[str] = None,
                 metric_kwargs: Optional[dict[str, Any]] = None):
        """
        Initialize the DeepMetric.

        Parameters
        ----------
        metric : type[Metric]
            A `Metric` class (or any callable returning one) to instantiate
            and apply at each level.
        alphas : list of float
            Coefficients for each prediction level. Must contain at least as
            many entries as there are inputs.
        name : str, optional
            Name of the metric (used for logging/display).
        metric_kwargs : dict, optional
            Keyword arguments for the base metric constructor. `None` means
            no extra arguments.
        """
        super(DeepMetric, self).__init__()
        # `metric_kwargs` defaults to None, which cannot be **-unpacked.
        self.metric = metric(**(metric_kwargs or {}))
        self.name = name
        self.alphas = alphas

    def forward(self,
                inputs: list[torch.Tensor],
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the deep supervision metric across multiple prediction levels.

        Parameters
        ----------
        inputs : list of torch.Tensor
            Network outputs, one per level; each is compared to `targets`.
        targets : torch.Tensor
            Ground truth shared across all prediction levels.

        Returns
        -------
        self.val: torch.Tensor
            Weighted sum of metric values across levels.

        Raises
        ------
        IndexError
            If `alphas` has fewer entries than `inputs`.
        """
        if len(self.alphas) < len(inputs):
            # Fail before any work is done instead of part-way through the loop.
            raise IndexError(
                "DeepMetric received {} inputs but only {} alphas".format(
                    len(inputs), len(self.alphas)))

        self.val = 0
        for level_output, alpha in zip(inputs, self.alphas):
            # Levels with a zero weight are skipped entirely (no forward pass).
            if alpha != 0:
                self.val += self.metric(level_output, targets) * alpha
        return self.val
#--------------------------------------------------------------------------- # nnUNet metrics
[docs] def sum_tensor(inp:torch.Tensor, axes:int|tuple[int]|list[int], keepdim:bool=False, )->torch.Tensor: """ Sum a tensor over specified axes. Parameters ---------- inp : torch.Tensor Input tensor to reduce. axes : int or list/tuple of int Axes over which the tensor is summed. keepdim : bool, default=False Whether to retain reduced dimensions (with size 1). Returns ------- inp: torch.Tensor Reduced tensor with summed values. No copy. """ axes = np.unique(axes).astype(int) if keepdim: for ax in axes: inp = inp.sum(int(ax), keepdim=True) else: for ax in sorted(axes, reverse=True): inp = inp.sum(int(ax)) return inp
[docs] def get_tp_fp_fn_tn(net_output: torch.Tensor, gt: torch.Tensor, axes: Optional[int| tuple[int]| list[int]] = None, mask: Optional[torch.Tensor] = None, square: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute true positives (TP), false positives (FP), false negatives (FN), and true negatives (TN) between network outputs and ground truth labels. Assumes the input `net_output` is in shape (B, C, ...) and `gt` is either a label map (B, ...) or one-hot encoded (B, C, ...). If not already in one-hot form, the function converts `gt`. Parameters ---------- net_output : torch.Tensor Model prediction tensor with shape (B, C, H, W, (D)). gt : torch.Tensor Ground truth tensor. Can be label map (B, H, W, (D)) or one-hot encoded (B, C, H, W, (D)). axes : int or list/tuple of int, optional Axes to reduce over (e.g., spatial dimensions). mask : torch.Tensor, optional Optional binary mask of shape (B, 1, H, W, (D)) where 1 indicates valid pixels. square : bool, default=False Whether to square the TP/FP/FN/TN tensors before summation. Returns ------- tp torch.Tensor True positives after reduction per class. fp torch.Tensor False positives after reduction per class. tn torch.Tensor True negatives after reduction per class. fn torch.Tensor False negatives after reduction per class. 
""" if axes is None: axes = tuple(range(2, len(net_output.size()))) shp_x = net_output.shape shp_y = gt.shape with torch.no_grad(): if len(shp_x) != len(shp_y): gt = gt.view((shp_y[0], 1, *shp_y[1:])) if all([i == j for i, j in zip(net_output.shape, gt.shape)]): # if this is the case then gt is probably already a one hot encoding y_onehot = gt else: gt = gt.long() y_onehot = torch.zeros(shp_x, device=net_output.device) y_onehot.scatter_(1, gt, 1) tp = net_output * y_onehot fp = net_output * (1 - y_onehot) fn = (1 - net_output) * y_onehot tn = (1 - net_output) * (1 - y_onehot) if mask is not None: tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) if square: tp = tp ** 2 fp = fp ** 2 fn = fn ** 2 tn = tn ** 2 if len(axes) > 0: tp = sum_tensor(tp, axes, keepdim=False) fp = sum_tensor(fp, axes, keepdim=False) fn = sum_tensor(fn, axes, keepdim=False) tn = sum_tensor(tn, axes, keepdim=False) return tp, fp, fn, tn
class SoftDiceLoss(nn.Module):
    """
    Soft Dice loss.

    Computes a soft Dice coefficient from the soft TP/FP/FN statistics of the
    prediction and returns its negated mean, so that minimising the loss
    maximises the Dice.

    Parameters
    ----------
    apply_nonlin : callable, optional
        Optional non-linearity to apply to the prediction (e.g., a softmax or
        sigmoid).
    batch_dice : bool, default=False
        If True, the batch axis is reduced together with the spatial axes.
    do_bg : bool, default=True
        If False, channel 0 (background) is excluded from the mean.
    smooth : float, default=1.
        Smoothing factor added to numerator and denominator.
    """

    def __init__(self,
                 apply_nonlin: Optional[Callable] = None,
                 batch_dice: bool = False,
                 do_bg: bool = True,
                 smooth: float = 1.0):
        """
        Initialize the SoftDiceLoss.

        Parameters
        ----------
        apply_nonlin : callable, optional
            Non-linearity applied to the predictions. If None, no activation
            is applied.
        batch_dice : bool, default=False
            If True, computes Dice over the entire batch as a whole; otherwise
            computes it per sample.
        do_bg : bool, default=True
            If False, excludes the background class (channel 0) from the Dice
            computation.
        smooth : float, default=1.0
            Smoothing factor added to numerator and denominator.
        """
        super().__init__()
        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth

    def forward(self,
                x: torch.Tensor,
                y: torch.Tensor,
                loss_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the Soft Dice loss.

        Parameters
        ----------
        x : torch.Tensor
            Network predictions with the class axis at position 1.
        y : torch.Tensor
            Ground truth labels, either label maps or one-hot encoded tensors.
        loss_mask : torch.Tensor, optional
            Mask forwarded to `get_tp_fp_fn_tn`.

        Returns
        -------
        -dc: torch.Tensor
            Scalar loss value (negative mean Dice coefficient).
        """
        # Every axis after the class axis; the batch axis is added for batch dice.
        reduce_axes = list(range(2, x.dim()))
        if self.batch_dice:
            reduce_axes = [0] + reduce_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, reduce_axes, loss_mask, False)

        numerator = 2 * tp + self.smooth
        denominator = 2 * tp + fp + fn + self.smooth
        dice = numerator / (denominator + 1e-8)

        if not self.do_bg:
            # With batch dice the class axis is axis 0, otherwise axis 1.
            dice = dice[1:] if self.batch_dice else dice[:, 1:]

        return -dice.mean()
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """
    Compatibility wrapper around `nn.CrossEntropyLoss`.

    Accepts targets that carry an extra singleton axis at position 1 and/or a
    non-integer dtype: the extra axis is dropped and the target is cast to
    long before being handed to the parent loss.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the cross-entropy loss with robust target handling.

        Parameters
        ----------
        input : torch.Tensor
            Predictions (logits) with the class axis at position 1.
        target : torch.Tensor
            Target labels; may have the same number of dimensions as `input`,
            in which case axis 1 must have size 1.

        Returns
        -------
        torch.Tensor
            Loss value as returned by `nn.CrossEntropyLoss`.
        """
        has_channel_axis = target.ndim == input.ndim
        if has_channel_axis:
            assert target.shape[1] == 1
            target = target.select(1, 0)
        labels = target.long()
        return super().forward(input, labels)
class DC_and_CE_loss(Metric):
    """
    Combined Dice and Cross-Entropy loss metric.

    Combines a Soft Dice loss and a Cross Entropy loss, with optional
    weighting and an optional ignored label. Weights for CE and Dice do not
    need to sum to one and can be set independently.

    :ivar RobustCrossEntropyLoss ce: Cross entropy loss module.
    :ivar SoftDiceLoss dc: Soft Dice loss module (None when `square_dice` is requested with a zero Dice weight).
    :ivar float weight_ce: Weight for the Cross Entropy loss.
    :ivar float weight_dice: Weight for the Dice loss.
    :ivar bool log_dice: Whether to log-transform the Dice loss.
    :ivar int | None ignore_label: Label to ignore during loss calculation.
    :ivar str aggregate: Method to aggregate losses (currently supports "sum" only).
    :ivar str name: Name of the metric.
    """

    def __init__(self,
                 soft_dice_kwargs: dict[str, Any],
                 ce_kwargs: dict[str, Any],
                 aggregate: Literal["sum"] = "sum",
                 square_dice: bool = False,
                 weight_ce: float = 1.0,
                 weight_dice: float = 1.0,
                 log_dice: bool = False,
                 ignore_label: Optional[int] = None,
                 name: Optional[str] = None):
        """
        Initialize the combined Dice and Cross Entropy loss metric.

        Parameters
        ----------
        soft_dice_kwargs : dict of str to any
            Keyword arguments for SoftDiceLoss.
        ce_kwargs : dict of str to any
            Keyword arguments for RobustCrossEntropyLoss. The dict is copied,
            never modified.
        aggregate : "sum", default="sum"
            Aggregation method for losses. Only "sum" supported.
        square_dice : bool, default=False
            Use the squared Dice loss variant. Not implemented here.
        weight_ce : float, default=1
            Weight for the Cross Entropy component.
        weight_dice : float, default=1
            Weight for the Dice component.
        log_dice : bool, default=False
            If True, apply negative log transform to the Dice loss.
        ignore_label : int, optional
            Label to ignore during loss calculation.
        name : str, optional
            Name of the metric.

        Raises
        ------
        NotImplementedError
            If `square_dice` is True while the Dice term is actually used
            (`weight_dice != 0`).
        """
        super(DC_and_CE_loss, self).__init__()

        # Work on a copy so the caller's dict is not altered below.
        ce_kwargs = dict(ce_kwargs)
        if ignore_label is not None:
            assert not square_dice, 'not implemented'
            # Per-element losses are needed to mask out the ignored label.
            ce_kwargs['reduction'] = 'none'

        self.log_dice = log_dice
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.aggregate = aggregate
        self.ce = RobustCrossEntropyLoss(**ce_kwargs)

        self.ignore_label = ignore_label

        if not square_dice:
            self.dc = SoftDiceLoss(apply_nonlin=lambda x: F.softmax(x, 1), **soft_dice_kwargs)
        elif weight_dice != 0:
            # Fail at construction rather than with an AttributeError on the
            # first forward pass.
            raise NotImplementedError("square_dice=True is not implemented")
        else:
            # Dice term disabled: forward() never touches self.dc.
            self.dc = None

        self.name = name

    def forward(self,
                net_output: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Compute the combined Dice and Cross Entropy loss.

        Parameters
        ----------
        net_output : torch.Tensor
            Model predictions (logits) with the class axis at position 1.
        target : torch.Tensor
            Ground truth with a singleton axis 1 (class indices), or any other
            size on axis 1, in which case it is reduced with an argmax over
            that axis. The tensor is not modified.

        Returns
        -------
        self.val: torch.Tensor
            Combined loss scalar.
        """
        if target.shape[1] != 1:
            target = target.argmax(dim=1).long().unsqueeze(dim=1)

        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'not implemented for one hot encoding'
            mask = target != self.ignore_label
            # Out-of-place: the caller's target tensor must stay intact.
            target = target.masked_fill(~mask, 0)
            mask = mask.float()
        else:
            mask = None

        use_dice = self.weight_dice != 0
        dc_loss = self.dc(net_output, target, loss_mask=mask) if use_dice else 0
        if self.log_dice and use_dice:
            # Only meaningful for a real Dice tensor, not the placeholder 0.
            dc_loss = -torch.log(-dc_loss)

        ce_loss = self.ce(net_output, target[:, 0].long()) if self.weight_ce != 0 else 0
        if self.ignore_label is not None:
            ce_loss = ce_loss * mask[:, 0]
            # clamp: when every element is ignored the numerator is 0, so
            # dividing by 1 yields 0 instead of NaN.
            ce_loss = ce_loss.sum() / mask.sum().clamp(min=1)

        if self.aggregate == "sum":
            result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        else:
            raise NotImplementedError("nah son")  # reserved for other stuff (later)

        self.val = result
        return self.val
#---------------------------------------------------------------------------