Metrics

Metrics serve two purposes: training the model, in which case they are called a ‘loss’, and monitoring the training process. All metrics must inherit from the biom3d.metrics.Metric class. Once defined, a new metric can be used as a loss function by adding it to the config file as follows:

>>> TRAIN_LOSS = Dict(
>>>     fct="DiceBCE",
>>>     kwargs = Dict(name="train_loss", use_softmax=True)
>>> )

or as a metric with the following:

>>> VAL_METRICS = Dict(
>>>     val_iou=Dict(
>>>         fct="IoU",
>>>         kwargs = Dict(name="val_iou", use_softmax=USE_SOFTMAX)),
>>>     val_dice=Dict(
>>>         fct="Dice",
>>>         kwargs=Dict(name="val_dice", use_softmax=USE_SOFTMAX)),
>>> )
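
One way such an entry could be turned into a metric object is sketched below. The `REGISTRY` dictionary, the `build_metric` helper and the `FakeDiceBCE` class are hypothetical names used only for illustration, and a plain `dict` stands in for biom3d's `Dict`; the actual loading code in biom3d may work differently.

```python
# Hypothetical stand-in for a metric class; not the biom3d implementation.
class FakeDiceBCE:
    def __init__(self, name=None, use_softmax=False):
        self.name = name
        self.use_softmax = use_softmax


# Hypothetical registry mapping the "fct" string to a class.
REGISTRY = {"DiceBCE": FakeDiceBCE}


def build_metric(entry):
    """Instantiate the class named by entry["fct"] with entry["kwargs"]."""
    cls = REGISTRY[entry["fct"]]
    return cls(**entry.get("kwargs", {}))


TRAIN_LOSS = dict(
    fct="DiceBCE",
    kwargs=dict(name="train_loss", use_softmax=True),
)

loss = build_metric(TRAIN_LOSS)
print(loss.name, loss.use_softmax)
```

The point of the sketch is only that `fct` selects the class and `kwargs` is forwarded to its constructor.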

Metrics and losses, mostly for segmentation.

class biom3d.metrics.CrossEntropy(*args: Any, **kwargs: Any)[source]

Cross-entropy loss metric.

This metric computes the average cross-entropy between predicted class scores and target class indices. Typically used for classification problems.

Variables:
  • name (str) – Name of the metric (used in logging or display).

  • ce (torch.nn.CrossEntropyLoss) – Internal cross-entropy loss module.

__init__(name: str | None = None)[source]

Initialize the CrossEntropy metric.

Parameters:

name (str, optional) – Name of the metric, for display or logging purposes.

forward(inputs: torch.Tensor, targets: torch.Tensor) torch.Tensor[source]

Compute the cross-entropy loss between predictions and targets.

Parameters:
  • inputs (torch.Tensor) – Raw class scores (logits) of shape (N, C, …).

  • targets (torch.Tensor) – Ground truth class indices of shape (N, …).

Returns:

Computed cross-entropy loss.

Return type:

torch.Tensor
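
To make the quantity concrete, the following is a minimal plain-Python sketch of an averaged cross-entropy over logits and class indices. It uses nested lists instead of tensors and handles only the (N, C) case; the `cross_entropy` helper is an illustrative name, not part of biom3d.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy for a batch.

    logits: list of N rows, each a list of C raw class scores.
    targets: list of N class indices.
    """
    total = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp over the class scores.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in row))
        # Negative log-probability of the true class.
        total += log_sum_exp - row[target]
    return total / len(targets)


# Two samples, two classes: equal scores give ln(2) for each sample.
loss = cross_entropy([[0.0, 0.0], [1.0, 1.0]], [0, 1])
print(loss)
```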

class biom3d.metrics.DC_and_CE_loss(*args: Any, **kwargs: Any)[source]

Combined Dice and Cross-Entropy loss metric.

This metric combines Soft Dice loss and Cross Entropy loss, with optional weighting and masking support. Weights for CE and Dice do not need to sum to one and can be set independently.

Variables:
  • ce (RobustCrossEntropyLoss) – Cross entropy loss module.

  • dc (SoftDiceLoss) – Soft Dice loss module.

  • weight_ce (float) – Weight for the Cross Entropy loss.

  • weight_dice (float) – Weight for the Dice loss.

  • log_dice (bool) – Whether to log-transform the Dice loss.

  • ignore_label (int | None) – Label to ignore during loss calculation.

  • aggregate (str) – Method to aggregate losses (currently supports “sum” only).

  • name (str) – Name of the metric.

__init__(soft_dice_kwargs: dict[str, Any], ce_kwargs: dict[str, Any], aggregate: Literal['sum'] = 'sum', square_dice: bool = False, weight_ce: float = 1.0, weight_dice: float = 1.0, log_dice: bool = False, ignore_label: int | None = None, name: str | None = None)[source]

Initialize the combined Dice and Cross Entropy loss metric.

Parameters:
  • soft_dice_kwargs (dict of str to any) – Keyword arguments for SoftDiceLoss.

  • ce_kwargs (dict of str to any) – Keyword arguments for RobustCrossEntropyLoss.

  • aggregate ("sum", default="sum") – Aggregation method for losses. Only “sum” is supported.

  • square_dice (bool, default=False) – Use squared Dice loss variant (not implemented with ignore_label).

  • weight_ce (float, default=1) – Weight for the Cross Entropy component.

  • weight_dice (float, default=1) – Weight for the Dice component.

  • log_dice (bool, default=False) – If True, apply negative log transform to Dice loss.

  • ignore_label (int, optional) – Label to ignore during loss calculation.

  • name (str, optional) – Name of the metric.

forward(net_output: torch.Tensor, target: torch.Tensor) torch.Tensor[source]

Compute the combined Dice and Cross Entropy loss.

Parameters:
  • net_output (torch.Tensor) – Model predictions (logits), shape (batch_size, num_classes, …).

  • target (torch.Tensor) – Ground truth tensor, shape (batch_size, 1, …) or one-hot encoded.

Returns:

self.val – Combined loss scalar.

Return type:

torch.Tensor
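
The weighting described above can be sketched on plain floats as follows. This assumes the Dice term arrives as a negative Dice coefficient (as SoftDiceLoss is documented to return) and that the log transform is `-log(Dice)`; `combine_dice_and_ce` is a hypothetical helper name, and the real class may differ in detail.

```python
import math


def combine_dice_and_ce(ce_loss, dc_loss, weight_ce=1.0, weight_dice=1.0,
                        log_dice=False):
    """Weighted sum of a cross-entropy value and a soft Dice loss value.

    dc_loss is assumed to be the negative Dice coefficient, in [-1, 0).
    """
    if log_dice:
        # Assumed form of the negative log transform: -log(Dice).
        dc_loss = -math.log(-dc_loss)
    return weight_ce * ce_loss + weight_dice * dc_loss


total = combine_dice_and_ce(ce_loss=0.5, dc_loss=-0.8)
print(total)
```

Because the two weights are applied independently, setting `weight_dice=0.0` reduces the result to the weighted cross-entropy alone.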

class biom3d.metrics.DeepMetric(*args: Any, **kwargs: Any)[source]

Metric wrapper for deep supervision.

Applies a given metric to multiple feature maps (from different decoder depths, e.g., in U-Net) and combines the per-level values as a weighted sum, using one coefficient (alpha) per level. Each feature map prediction is compared to the same ground truth.

Variables:
  • metric (Metric) – Base metric applied at each level.

  • alphas (list[float]) – Weights associated with each level’s output.

  • name (str) – Name of the metric.

__init__(metric: type[Metric], alphas: list[float], name: str | None = None, metric_kwargs: dict[str, Any] | None = None)[source]

Initialize the DeepMetric.

Parameters:
  • metric (type[Metric]) – A callable class or constructor of a Metric to apply at each level.

  • alphas (list of float) – List of coefficients for each prediction level. Must match the number of inputs.

  • name (str, optional) – Name of the metric (used for logging/display).

  • metric_kwargs (dict, optional) – Additional keyword arguments for the base metric constructor.

forward(inputs: list[torch.Tensor], targets: torch.Tensor) torch.Tensor[source]

Compute the deep supervision metric across multiple prediction levels.

Parameters:
  • inputs (list of torch.Tensor) – List of model outputs at various depths.

  • targets (torch.Tensor) – Ground truth labels shared across all prediction levels.

Returns:

self.val – Weighted sum of metric values across levels.

Return type:

torch.Tensor
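
The weighted sum over prediction levels can be illustrated with a small plain-Python sketch. Here `deep_metric` and `mean_abs_error` are hypothetical helpers operating on flat lists; they only mirror the idea of one alpha per level and are not the biom3d code.

```python
def deep_metric(metric, alphas, inputs, targets):
    """Weighted sum of metric(input, targets) over all prediction levels."""
    if len(alphas) != len(inputs):
        raise ValueError("alphas and inputs must have the same length")
    return sum(a * metric(x, targets) for a, x in zip(alphas, inputs))


def mean_abs_error(pred, true):
    # Hypothetical base metric used only for this illustration.
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


targets = [1.0, 0.0]
inputs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]  # three prediction levels
alphas = [1.0, 0.5, 0.25]

value = deep_metric(mean_abs_error, alphas, inputs, targets)
print(value)
```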

class biom3d.metrics.Dice(*args: Any, **kwargs: Any)[source]

Dice score computation metric.

Computes the Dice coefficient between predictions and targets. Supports binary and multi-class segmentation with optional softmax normalization. Background channel can be automatically removed for multi-class setups.

Variables:
  • name (str) – Name of the metric (used in logs).

  • use_softmax (bool) – Whether to apply softmax before Dice computation.

  • dim (tuple[int]) – Dimensions over which Dice is computed (e.g., (2, 3) or (2, 3, 4)).

__init__(use_softmax: bool = False, dim: tuple[int] = (), name: str = None)[source]

Initialize the Dice metric.

Parameters:
  • use_softmax (bool, default=False) – Whether to apply softmax to model outputs. If True, background channel is removed.

  • dim (tuple, default=()) – Dimensions along which Dice is computed. Use (2,3) for 2D or (2,3,4) for 3D images.

  • name (str, optional) – Name of the metric for logging and display.

forward(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) torch.Tensor[source]

Compute Dice coefficient between predictions and ground truth targets.

Parameters:
  • inputs (torch.Tensor) – Model outputs (logits or probabilities).

  • targets (torch.Tensor) – Ground truth segmentation masks.

  • smooth (float, default=1.0) – Smoothing constant to avoid division by zero.

Returns:

self.val – Dice score as a scalar tensor.

Return type:

torch.Tensor
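
As an illustration of the role of `smooth`, here is a minimal sketch of a commonly used smoothed Dice formula on flat lists. It is an assumption that biom3d uses this exact form; `dice_score` is a hypothetical helper, and softmax and background removal are left out.

```python
def dice_score(pred, target, smooth=1.0):
    """Smoothed Dice coefficient for flat lists of values in [0, 1]."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)


pred = [1.0, 1.0, 0.0, 0.0]
target = [1.0, 0.0, 0.0, 0.0]
score = dice_score(pred, target)
print(score)
```

With `smooth` greater than zero, two empty masks give a score of 1 instead of a division by zero.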

class biom3d.metrics.DiceBCE(*args: Any, **kwargs: Any)[source]

Dice + Binary Cross-Entropy (BCE) metric.

This combined metric can also be used as a loss function. It computes the sum of the BCE loss and the Dice loss, supporting both binary and multi-class segmentation with optional softmax activation. For multi-class predictions, background is removed from Dice computation if use_softmax is True.

Variables:
  • name (str) – Name of the metric (used in logs).

  • use_softmax (bool) – Whether to apply softmax and remove background for Dice computation.

  • dim (tuple[int]) – Dimensions over which Dice is computed.

  • bce (torch.nn.CrossEntropyLoss) – BCE loss function module.

__init__(use_softmax: bool = False, dim: tuple[int] = (), name: str | None = None)[source]

Initialize the DiceBCE metric.

Parameters:
  • use_softmax (bool, default=False) – Whether to apply softmax to model outputs. If True, background is excluded from Dice.

  • dim (tuple, default=()) – Dimensions along which Dice is computed (e.g., (2, 3) or (2, 3, 4)).

  • name (str, optional) – Name of the metric for logging and display.

forward(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) torch.Tensor[source]

Compute the combined Dice and BCE loss.

Parameters:
  • inputs (torch.Tensor) – Model outputs (logits).

  • targets (torch.Tensor) – Ground truth segmentation masks (can be one-hot or class indices).

  • smooth (float, default=1.0) – Smoothing constant to avoid division by zero in Dice computation.

Returns:

self.val – Sum of the Dice loss and the BCE loss.

Return type:

torch.Tensor

class biom3d.metrics.IoU(*args: Any, **kwargs: Any)[source]

Intersection over Union (IoU) score computation.

This metric measures the overlap between the predicted and ground truth segmentation masks. Can be used as a metric or a loss function (when inverted during training). Supports both binary and multi-class segmentation tasks, with optional softmax activation and background removal.

Variables:
  • name (str) – Name of the metric.

  • use_softmax (bool) – Whether to apply softmax (multi-class case).

  • dim (tuple) – Dimensions along which the IoU is computed.

__init__(use_softmax: bool = False, dim: tuple = (), name: str | None = None)[source]

Initialize the IoU metric.

Parameters:
  • use_softmax (bool, default=False) – Whether to apply softmax to model outputs. If True, background is excluded from IoU computation.

  • dim (tuple, default=()) – Dimensions over which IoU is computed (e.g., (2, 3, 4) for 3D or (2, 3) for 2D).

  • name (str, optional) – Name of the metric (for logging or display).

forward(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) torch.Tensor[source]

Compute the IoU between predicted and target masks.

Parameters:
  • inputs (torch.Tensor) – Logits output by the model.

  • targets (torch.Tensor) – Ground truth segmentation masks (either one-hot or class indices).

  • smooth (float, default=1.0) – Smoothing term to avoid division by zero.

Returns:

self.val – Computed IoU score.

Return type:

torch.Tensor
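
For comparison with Dice, a smoothed IoU can be sketched on flat lists as below. This is a common formulation and is assumed rather than taken from the biom3d source; `iou_score` is a hypothetical helper name.

```python
def iou_score(pred, target, smooth=1.0):
    """Smoothed intersection over union for flat lists of values in [0, 1]."""
    intersection = sum(p * t for p, t in zip(pred, target))
    union = sum(pred) + sum(target) - intersection
    return (intersection + smooth) / (union + smooth)


pred = [1.0, 1.0, 0.0, 0.0]
target = [1.0, 0.0, 0.0, 0.0]
score = iou_score(pred, target)
print(score)
```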

class biom3d.metrics.MSE(*args: Any, **kwargs: Any)[source]

Mean Squared Error (MSE) loss.

This metric computes the average squared difference between predictions and targets. Commonly used as a regression loss, but can also serve as a metric in training.

Variables:

name (str) – Name of the metric (for logging or display).

__init__(name: str | None = None)[source]

Initialize the MSE metric.

Parameters:

name (str, optional) – Name of the metric, for display or logging purposes.

forward(inputs: torch.Tensor, targets: torch.Tensor) torch.Tensor[source]

Compute the mean squared error between predictions and targets.

Parameters:
  • inputs (torch.Tensor) – Predicted values.

  • targets (torch.Tensor) – Ground truth values.

Returns:

Computed MSE loss.

Return type:

torch.Tensor

class biom3d.metrics.Metric(*args: Any, **kwargs: Any)[source]

Abstract base class for all metrics.

Designed to store and update metric values such as the current value, average, sum, and count. In the biom3d structure, all metrics must implement the following:
  • a name attribute for identification,

  • a reset() method to clear internal states,

  • a forward() method to compute the metric given predictions and targets,

  • an update() method to accumulate statistics.

To define a new metric:
  1. Inherit from Metric.

  2. Set the name attribute in __init__.

  3. Implement forward() to compute the metric and assign to self.val.

Variables:
  • name (str) – Name of the metric (used in logging/display).

  • val (float) – Current value of the metric for the last batch.

  • avg (float) – Running average of the metric.

  • sum (float) – Cumulative sum of the metric across all updates.

  • count (int) – Number of updates applied (used to compute average).

__init__(name: str | None = None)[source]

Initialize the base Metric class.

Parameters:

name (str, optional) – Name of the metric, for display/logging purposes.

abstractmethod forward(preds: torch.Tensor, trues: torch.Tensor) None[source]

Compute the metric value for the given predictions and targets.

Must be implemented by subclass.

Parameters:
  • preds (torch.Tensor) – Model predictions.

  • trues (torch.Tensor) – Ground truth targets.

Return type:

None

reset() None[source]

Reset internal metric statistics.

Return type:

None

str() str[source]

Return string representation of the current value.

Returns:

String representation of the current value.

Return type:

str

update(n: int = 1) None[source]

Update metric statistics with the current value.

Parameters:

n (int, default=1) – Number of samples in the current batch.

Return type:

None
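
The bookkeeping and the three-step recipe for a new metric can be sketched without torch as follows. `RunningMetric` is a plain-Python stand-in for the base class and `MeanAbsError` is a hypothetical metric; the weighting of `val` by `n` in `update()` is an assumption about how the running average is kept.

```python
class RunningMetric:
    """Plain-Python stand-in mirroring the val/avg/sum/count bookkeeping."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def forward(self, preds, trues):
        raise NotImplementedError

    def update(self, n=1):
        # Weight the last batch value by its sample count (assumed behaviour).
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count


class MeanAbsError(RunningMetric):
    """Hypothetical metric: inherit, set the name, implement forward()."""

    def __init__(self, name="mae"):
        super().__init__(name=name)

    def forward(self, preds, trues):
        self.val = sum(abs(p - t) for p, t in zip(preds, trues)) / len(trues)
        return self.val


metric = MeanAbsError()
metric.forward([1.0, 3.0], [1.0, 1.0])  # val = 1.0 on a batch of 2
metric.update(n=2)
metric.forward([2.0], [0.0])            # val = 2.0 on a batch of 1
metric.update(n=1)
print(metric.name, metric.avg)
```

In a real biom3d metric the base class would be biom3d.metrics.Metric and the inputs would be torch tensors.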

class biom3d.metrics.RobustCrossEntropyLoss(*args: Any, **kwargs: Any)[source]

Compatibility wrapper for CrossEntropyLoss to handle target tensors with an extra singleton dimension and float dtype.

This class modifies the target tensor shape and type to fit the expected input of nn.CrossEntropyLoss, which requires targets to be LongTensor without extra dimensions.

forward(input: torch.Tensor, target: torch.Tensor) torch.Tensor[source]

Compute the cross-entropy loss with robust target handling.

Parameters:
  • input (torch.Tensor) – Predictions (logits) of shape (B, C, …) where B is batch size and C number of classes.

  • target (torch.Tensor) – Target labels which might have an extra dimension (e.g. shape (B, 1, …)) and float dtype.

Returns:

Scalar loss value.

Return type:

torch.Tensor

class biom3d.metrics.SoftDiceLoss(*args: Any, **kwargs: Any)[source]

Soft Dice loss implementation.

This loss is useful for segmentation tasks, especially when dealing with class imbalance. It computes a soft version of the Dice coefficient, optionally over the batch dimension.

Parameters:
  • apply_nonlin (callable, optional) – Optional non-linearity to apply to the prediction (e.g., torch.softmax or torch.sigmoid).

  • batch_dice (bool, default=False) – If True, compute Dice over the entire batch instead of per-sample.

  • do_bg (bool, default=True) – If False, ignore background channel (assumed to be channel index 0).

  • smooth (float, default=1.) – Smoothing factor added to numerator and denominator to avoid division by zero.

__init__(apply_nonlin: Callable | None = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.0)[source]

Initialize the SoftDiceLoss.

Parameters:
  • apply_nonlin (callable, optional) – Non-linearity function to apply to the predictions (e.g., torch.sigmoid or torch.softmax). If None, no activation is applied.

  • batch_dice (bool, default=False) – If True, computes Dice over the entire batch as a whole; otherwise computes per sample.

  • do_bg (bool, default=True) – If False, excludes the background class (channel 0) from the Dice computation.

  • smooth (float, default=1.0) – Smoothing factor added to numerator and denominator to avoid division by zero.

forward(x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) torch.Tensor[source]

Compute the Soft Dice loss.

Parameters:
  • x (torch.Tensor) – Network predictions of shape (B, C, …) where B is batch size and C number of classes.

  • y (torch.Tensor) – Ground truth labels, either label maps or one-hot encoded tensors.

  • loss_mask (torch.Tensor, optional) – Mask tensor with shape (B, 1, …), where valid pixels are 1 and invalid are 0.

Returns:

-dc – Scalar loss value (negative mean Dice coefficient).

Return type:

torch.Tensor
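
How a negative mean Dice can be obtained from per-class counts is sketched below. The formula (2·TP + smooth) / (2·TP + FP + FN + smooth) is a common soft Dice formulation and is assumed here, not confirmed from the biom3d source; `soft_dice_loss` is a hypothetical helper working on lists of per-class sums.

```python
def soft_dice_loss(tp, fp, fn, smooth=1.0, do_bg=True):
    """Negative mean Dice over classes from per-class TP/FP/FN sums."""
    if not do_bg:
        # Drop the background class, assumed to sit at index 0.
        tp, fp, fn = tp[1:], fp[1:], fn[1:]
    dice = [
        (2.0 * t + smooth) / (2.0 * t + p + n + smooth)
        for t, p, n in zip(tp, fp, fn)
    ]
    return -sum(dice) / len(dice)


loss = soft_dice_loss(tp=[8.0, 3.0], fp=[1.0, 1.0], fn=[1.0, 1.0], smooth=0.0)
print(loss)
```

Passing `do_bg=False` leaves only the foreground class in the average, which gives -0.75 for the same counts.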

biom3d.metrics.get_tp_fp_fn_tn(net_output: torch.Tensor, gt: torch.Tensor, axes: int | tuple[int] | list[int] | None = None, mask: torch.Tensor | None = None, square: bool = False) tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]

Compute true positives (TP), false positives (FP), false negatives (FN), and true negatives (TN) between network outputs and ground truth labels.

Assumes the input net_output is in shape (B, C, …) and gt is either a label map (B, …) or one-hot encoded (B, C, …). If not already in one-hot form, the function converts gt.

Parameters:
  • net_output (torch.Tensor) – Model prediction tensor with shape (B, C, H, W, (D)).

  • gt (torch.Tensor) – Ground truth tensor. Can be label map (B, H, W, (D)) or one-hot encoded (B, C, H, W, (D)).

  • axes (int or list/tuple of int, optional) – Axes to reduce over (e.g., spatial dimensions).

  • mask (torch.Tensor, optional) – Optional binary mask of shape (B, 1, H, W, (D)) where 1 indicates valid pixels.

  • square (bool, default=False) – Whether to square the TP/FP/FN/TN tensors before summation.

Returns:

  • tp torch.Tensor – True positives after reduction per class.

  • fp torch.Tensor – False positives after reduction per class.

  • fn torch.Tensor – False negatives after reduction per class.

  • tn torch.Tensor – True negatives after reduction per class.
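
For a single class, the four quantities can be illustrated on flat lists as below. The products used here (prediction times one-hot target, and their complements) are a usual way to obtain soft counts and are an assumption about the implementation; `soft_counts` is a hypothetical helper, and masking and squaring are omitted.

```python
def soft_counts(probs, onehot):
    """Return (tp, fp, fn, tn) for one class from flat lists."""
    tp = sum(p * y for p, y in zip(probs, onehot))
    fp = sum(p * (1 - y) for p, y in zip(probs, onehot))
    fn = sum((1 - p) * y for p, y in zip(probs, onehot))
    tn = sum((1 - p) * (1 - y) for p, y in zip(probs, onehot))
    return tp, fp, fn, tn


# Three pixels predicted as foreground, only the first is truly foreground.
tp, fp, fn, tn = soft_counts([1.0, 1.0, 1.0, 0.0], [1, 0, 0, 0])
print(tp, fp, fn, tn)
```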

biom3d.metrics.sum_tensor(inp: torch.Tensor, axes: int | tuple[int] | list[int], keepdim: bool = False) torch.Tensor[source]

Sum a tensor over specified axes.

Parameters:
  • inp (torch.Tensor) – Input tensor to reduce.

  • axes (int or list/tuple of int) – Axes over which the tensor is summed.

  • keepdim (bool, default=False) – Whether to retain reduced dimensions (with size 1).

Returns:

inp – Reduced tensor with summed values. No copy.

Return type:

torch.Tensor