Register
The register is a core module of Biom3d: it defines which functions can be used by each module.
Note
If you implement a new function/module, you may want to integrate it directly into Biom3d by adding it to the register. If so, you can add its documentation here and submit a pull request.
#---------------------------------------------------------------------------
# A register for all the existing methods
# Aim of this module:
# - to gather all required imports in a single file
# - to use it in collaboration with a config file
#---------------------------------------------------------------------------

from biom3d.utils import AttrDict

#---------------------------------------------------------------------------
# dataset register

from biom3d.datasets.semseg_patch_fast import SemSeg3DPatchFast
from biom3d.datasets.semseg_torchio import TorchioDataset

datasets = AttrDict(
    SegPatchFast =AttrDict(fct=SemSeg3DPatchFast, kwargs=AttrDict()),
    Torchio      =AttrDict(fct=TorchioDataset, kwargs=AttrDict()),
)

try:
    # Batchgen uses nnUnet batchgenerators, which may not be installed (it is not a dependency): pip install batchgenerators
    from biom3d.datasets.semseg_batchgen import MTBatchGenDataLoader
    datasets.BatchGen = AttrDict(fct=MTBatchGenDataLoader, kwargs=AttrDict())
except:
    pass

#---------------------------------------------------------------------------
# model register

from biom3d.models.encoder_vgg import VGGEncoder, EncoderBlock
from biom3d.models.unet3d_vgg_deep import UNet
from biom3d.models.encoder_efficientnet3d import EfficientNet3D
from biom3d.models.unet3d_eff import EffUNet
from monai.networks import nets

models = AttrDict(
    VGG3D         =AttrDict(fct=VGGEncoder, kwargs=AttrDict(block=EncoderBlock, use_head=True)),
    UNet3DVGGDeep =AttrDict(fct=UNet, kwargs=AttrDict()),
    Eff3D         =AttrDict(fct=EfficientNet3D.from_name, kwargs=AttrDict()),
    EffUNet       =AttrDict(fct=EffUNet, kwargs=AttrDict()),
    SwinUNETR     =AttrDict(fct=nets.SwinUNETR, kwargs=AttrDict()),
)

#---------------------------------------------------------------------------
# metric register

import biom3d.metrics as mt

metrics = AttrDict(
    Dice        =AttrDict(fct=mt.Dice, kwargs=AttrDict()),
    DiceBCE     =AttrDict(fct=mt.DiceBCE, kwargs=AttrDict()),
    DiceCEnnUNet=AttrDict(fct=mt.DC_and_CE_loss, kwargs=AttrDict(soft_dice_kwargs={'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, ce_kwargs={})),
    IoU         =AttrDict(fct=mt.IoU, kwargs=AttrDict()),
    MSE         =AttrDict(fct=mt.MSE, kwargs=AttrDict()),
    CE          =AttrDict(fct=mt.CrossEntropy, kwargs=AttrDict()),
    DeepMSE     =AttrDict(fct=mt.DeepMetric, kwargs=AttrDict(metric=mt.MSE)),
    DeepDiceBCE =AttrDict(fct=mt.DeepMetric, kwargs=AttrDict(metric=mt.DiceBCE)),
)

#---------------------------------------------------------------------------
# trainer register

from biom3d.trainers import (
    seg_train,
    seg_validate,
    seg_patch_validate,
    seg_patch_train,
)

trainers = AttrDict(
    SegTrain      =AttrDict(fct=seg_train, kwargs=AttrDict()),
    SegVal        =AttrDict(fct=seg_validate, kwargs=AttrDict()),
    SegPatchTrain =AttrDict(fct=seg_patch_train, kwargs=AttrDict()),
    SegPatchVal   =AttrDict(fct=seg_patch_validate, kwargs=AttrDict()),
)

#---------------------------------------------------------------------------
# Preprocessor and predictor register
# We register preprocessors here because they are needed to preprocess
# data before prediction.
# Preprocessors must correspond to the ones used to preprocess data
# before training.

from biom3d.preprocess import seg_preprocessor

preprocessors = AttrDict(
    Seg = AttrDict(fct=seg_preprocessor, kwargs=AttrDict())
)

from biom3d.predictors import (
    seg_predict,
    seg_predict_patch_2,
)

predictors = AttrDict(
    Seg = AttrDict(fct=seg_predict, kwargs=AttrDict()),
    SegPatch = AttrDict(fct=seg_predict_patch_2, kwargs=AttrDict()),
)

from biom3d.predictors import seg_postprocessing

postprocessors = AttrDict(
    Seg = AttrDict(fct=seg_postprocessing, kwargs=AttrDict())
)
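The register is designed to be used together with a config file: a config entry such as TRAIN_DATASET names a register key in fct and passes extra kwargs. The sketch below illustrates that lookup and how a new entry could be added. It is an illustration under assumptions, not Biom3d's own builder: AttrDict is a minimal stand-in for biom3d.utils.AttrDict so that the block runs on its own, build_from_register and my_metric are hypothetical names, and letting config kwargs override register kwargs is an assumed merge rule.

```python
# Minimal stand-in for biom3d.utils.AttrDict (assumption: a dict whose keys
# can also be read and written as attributes).
class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        self[name] = value


# Hypothetical metric, standing in for a real class such as biom3d.metrics.Dice.
def my_metric(smooth=1.0, name="my_metric"):
    return {"name": name, "smooth": smooth}


# A register section: each key maps to a callable (fct) and its default kwargs.
metrics = AttrDict(
    MyMetric=AttrDict(fct=my_metric, kwargs=AttrDict(smooth=1e-5)),
)

# Adding a new entry later, as you would when contributing a new module.
metrics.MyMetricNamed = AttrDict(fct=my_metric, kwargs=AttrDict(name="named"))


def build_from_register(register_section, config_entry):
    """Hypothetical builder: look up config_entry.fct in the register and call
    it with the register kwargs, overridden by the config kwargs (assumption)."""
    entry = register_section[config_entry.fct]
    kwargs = dict(entry.kwargs)
    kwargs.update(config_entry.get("kwargs", {}))
    return entry.fct(**kwargs)


config_loss = AttrDict(fct="MyMetric", kwargs=AttrDict(name="loss"))
loss = build_from_register(metrics, config_loss)
print(loss)  # {'name': 'loss', 'smooth': 1e-05}
```

Keeping the callable and its default kwargs side by side is what lets a config file refer to a module by a short string key.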
Let's now go through each module.
Dataloaders and batch generator
There are currently two dataloaders and one batch generator.
SegPatchFast
This is the default dataloading module.
- class biom3d.datasets.semseg_patch_fast.SemSeg3DPatchFast(*args: Any, **kwargs: Any)
Dataset class for semantic segmentation with 3D patches. Supports data augmentation and efficient loading.
- Variables:
img_path (str) – Path to collection containing the image files.
msk_path (str) – Path to collection containing the mask files.
fg_path (str | None) – Path to collection containing the foreground files.
batch_size (int) – Size of a batch.
patch_size (numpy.ndarray) – Size of a patch.
aug_patch_size (numpy.ndarray | None) – Size of the augmented patch; may be larger than patch_size.
nbof_steps (int) – Number of steps (batches) per epoch.
load_data (bool) – If True, load the entire dataset into memory.
handler (DataHandler) – DataHandler used to load data.
train (bool) – If True, use the dataset for training; otherwise, use it for validation.
fnames (list[str]) – List of image paths relative to img_path.
use_aug (bool) – Whether to use data augmentation.
fg_rate (float) – Foreground rate, used to force foreground inclusion in patches.
crop_scale (float) – Scale factor for crop size during augmentation.
use_softmax (bool) – If True, use softmax activation.
batch_idx (int) – Current batch index.
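fg_rate controls how often a patch is forced to contain foreground. As a rough illustration, the NumPy sketch below picks the start corner of one 3D patch: with probability fg_rate it centres the patch on a randomly chosen foreground voxel, otherwise it picks a uniformly random position. sample_patch_start is a hypothetical helper, and this centring strategy is an assumption rather than SemSeg3DPatchFast's exact logic.

```python
import numpy as np


def sample_patch_start(img_shape, patch_size, fg_coords, fg_rate, rng):
    """Hypothetical helper: return the start corner of a patch inside img_shape.

    img_shape, patch_size: sequences of ints (the patch must fit in the image).
    fg_coords: (N, ndim) array of foreground voxel coordinates (may be empty).
    fg_rate: probability of forcing a foreground voxel into the patch.
    """
    img_shape = np.asarray(img_shape)
    patch_size = np.asarray(patch_size)
    max_start = img_shape - patch_size  # last valid start along each axis

    if len(fg_coords) > 0 and rng.random() < fg_rate:
        # Centre the patch on a random foreground voxel; clipping below keeps
        # the patch inside the image and the voxel inside the patch.
        center = fg_coords[rng.integers(len(fg_coords))]
        start = center - patch_size // 2
    else:
        # Uniformly random position along each axis.
        start = rng.integers(0, max_start + 1)
    return np.clip(start, 0, max_start)


rng = np.random.default_rng(0)
mask = np.zeros((64, 64, 64), dtype=np.uint8)
mask[40:44, 10:14, 50:54] = 1
fg = np.argwhere(mask == 1)
start = sample_patch_start(mask.shape, (32, 32, 32), fg, fg_rate=1.0, rng=rng)
patch = mask[tuple(slice(s, s + 32) for s in start)]
print(patch.sum() > 0)  # True: with fg_rate=1.0 foreground is always included
```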
Torchio
This dataloader is an implementation of a TorchIO subject loader, optimized for TorchIO operations. It was written to ease data augmentation but is currently not used.
- class biom3d.datasets.semseg_torchio.TorchioDataset(*args: Any, **kwargs: Any)
Custom dataset similar to torchio.SubjectsDataset but supports an unlimited number of steps (batches) per epoch.
Handles loading of images, masks, and foreground data, train/validation splitting, optional in-memory data loading, and specific data augmentations.
- Variables:
img_path (str) – Path to the collection containing image files.
msk_path (str) – Path to the collection containing mask files.
fg_path (Optional[str]) – Path to the collection containing foreground data (optional).
batch_size (int) – Batch size for sampling.
patch_size (numpy.ndarray) – Size of the patches to extract.
aug_patch_size (Optional[numpy.ndarray]) – Size of patches used for augmentation (optional). Can be larger than patch_size.
nbof_steps (int) – Number of steps (batches) per epoch.
load_data (bool) – Whether to load all data into memory.
handler (DataHandler) – Data handler for loading images and masks.
train (bool) – Indicates if the dataset is used for training (True) or validation (False).
fnames (list[str]) – List of filenames used depending on training or validation mode.
subjects_list (list[Subject]) – List of TorchIO Subjects created from the files.
use_aug (bool) – Whether data augmentations are enabled.
fg_rate (float) – Foreground inclusion rate to force foreground sampling in patches.
use_softmax (bool) – Whether to use softmax activation; if False, sigmoid is used.
batch_idx (int) – Current batch index for internal tracking.
BatchGen
This is nnUnet's batchgenerator. It has been slightly modified to behave like a dataloader, so that it is easier to interchange with the others.
Note
If you want to use it, the config file must be modified differently, by defining dataloaders instead of datasets:
TRAIN_DATALOADER = Dict(
    fct="BatchGen",
    kwargs=Dict(
        # Insert parameters here
    )
)
VAL_DATALOADER = Dict(
    fct="BatchGen",
    kwargs=Dict(
        # Insert parameters here
    )
)
# Instead of
TRAIN_DATASET = Dict(
    fct="SegPatchFast",
    kwargs=Dict(
        # Insert parameters here
    )
)
VAL_DATASET = Dict(
    fct="SegPatchFast",
    kwargs=Dict(
        # Insert parameters here
    )
)
- class biom3d.datasets.semseg_batchgen.MTBatchGenDataLoader(*args: Any, **kwargs: Any)
Multi-threaded data loader for efficient data augmentation and loading.
- Variables:
length (int) – Number of batches.
Note
This class uses a package that is not among Biom3d's dependencies, so you need to install it manually: pip install batchgenerators
Models
UNet (registered as UNet3DVGGDeep) is a transcription of the nnUnet neural network and is the default model used by Biom3d.
- class biom3d.models.unet3d_vgg_deep.UNet(*args: Any, **kwargs: Any)
A 3D UNet architecture utilizing VGG-style encoder and decoder blocks for volumetric (3D) image segmentation.
The UNet model is a convolutional neural network for fast and precise segmentation of images. This implementation incorporates VGG blocks for encoding and decoding, allowing for deep feature extraction and reconstruction, respectively. The model supports dynamic adjustment of pooling layers and class numbers, along with optional deep decoder usage and weight initialization from pre-trained checkpoints.
- Variables:
encoder (VGGEncoder) – The encoder part of the UNet, responsible for downscaling and feature extraction.
decoder (VGGDecoder) – The decoder part of the UNet, responsible for upscaling and constructing the segmentation map.
UNet uses VGGEncoder as its encoder:
- class biom3d.models.encoder_vgg.VGGEncoder(*args: Any, **kwargs: Any)
VGG-style 3D encoder composed of multiple EncoderBlocks.
The architecture applies a sequence of blocks with progressively increasing number of channels, with configurable pooling and strides.
- Variables:
in_planes (int) – Number of input channels to the current layer.
use_emb (bool) – Whether embedding is used.
use_head (bool) – Whether fully connected head is used.
layers (ModuleList) – ModuleList containing the sequence of encoder layers.
head (nn.Sequential) – Optional fully connected head for embedding (if use_head is True).
UNet uses VGGDecoder as its decoder.
EffUNet, from biom3d.models.unet3d_eff, is an alternative UNet-style model.
EffUNet uses EfficientNet3D as its encoder:
- class biom3d.models.encoder_efficientnet3d.EfficientNet3D(*args: Any, **kwargs: Any)
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods.
- Variables:
_blocks (nn.ModuleList) – List of MBConvBlock3D blocks making up the network.
_conv_stem (nn.Conv3d) – Initial convolutional stem layer.
_conv_head (nn.Conv3d) – Final convolutional head layer before classifier.
_bn0 (nn.InstanceNorm3d) – Normalization layer after the stem.
_bn1 (nn.BatchNorm3d) – Batch normalization after the head.
_avg_pooling (nn.AdaptiveAvgPool3d) – Global average pooling layer.
_dropout (nn.Dropout) – Dropout layer before the classifier.
_fc (nn.Linear) – Fully connected linear layer for classification.
_swish (nn.Module) – Swish activation function module.
Examples
>>> model = EfficientNet3D.from_name('efficientnet-b0')
>>> x = torch.randn(1, 3, 224, 224, 224)
>>> logits = model(x)
EffUNet uses VGGDecoder as its decoder.
Note
Both encoders can also be used as models on their own (registered as VGG3D and Eff3D).
Metrics
Here are the available metrics:
Dice
One of the most widely used metrics. It compares two images (typically a prediction and a mask) and returns a number between 0 and 1; the closer it is to 1, the more similar the images are.
- class biom3d.metrics.Dice(*args: Any, **kwargs: Any)
Dice score computation metric.
Computes the Dice coefficient between predictions and targets. Supports binary and multi-class segmentation with optional softmax normalization. Background channel can be automatically removed for multi-class setups.
- Variables:
name (str) – Name of the metric (used in logs).
use_softmax (bool) – Whether to apply softmax before Dice computation.
dim (tuple[int]) – Dimensions over which Dice is computed (e.g., (2, 3) or (2, 3, 4)).
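The Dice coefficient between a prediction P and a mask G is 2|P ∩ G| / (|P| + |G|). Below is a minimal NumPy sketch for binary masks. The smooth term, which avoids a division by zero when both masks are empty, is an assumption of this sketch; biom3d.metrics.Dice itself works on PyTorch tensors and also handles softmax and background removal.

```python
import numpy as np


def dice(pred, target, smooth=1e-5):
    """Dice coefficient between two binary arrays of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)


pred = np.array([1, 1, 1, 0, 0])
target = np.array([0, 1, 1, 1, 0])
print(round(dice(pred, target), 3))  # 0.667, i.e. 2*2 / (3 + 3)
```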
CrossEntropy
Metric that compares the softmax of the logits with the mask.
- class biom3d.metrics.CrossEntropy(*args: Any, **kwargs: Any)
Cross-entropy loss metric.
This metric computes the average cross-entropy between predicted class scores and target class indices. Typically used for classification problems.
- Variables:
name (str) – Name of the metric (used in logging or display).
ce (torch.nn.CrossEntropyLoss) – Internal cross-entropy loss module.
DiceBCE
Metric that combines Dice and cross-entropy.
- class biom3d.metrics.DiceBCE(*args: Any, **kwargs: Any)
Dice + Binary Cross-Entropy (BCE) metric.
This combined metric can also be used as a loss function. It computes the sum of the BCE loss and the Dice loss, supporting both binary and multi-class segmentation with optional softmax activation. For multi-class predictions, background is removed from Dice computation if use_softmax is True.
- Variables:
name (str) – Name of the metric (used in logs).
use_softmax (bool) – Whether to apply softmax and remove background for Dice computation.
dim (tuple[int]) – Dimensions over which Dice is computed.
bce (torch.nn.CrossEntropyLoss) – BCE loss function module.
DC_and_CE_loss
nnUnet's implementation of combined Dice and cross-entropy (registered as DiceCEnnUNet). It is more robust but does not handle binary cases.
- class biom3d.metrics.DC_and_CE_loss(*args: Any, **kwargs: Any)
Combined Dice and Cross-Entropy loss metric.
This metric combines Soft Dice loss and Cross Entropy loss, with optional weighting and masking support. Weights for CE and Dice do not need to sum to one and can be set independently.
- Variables:
ce (RobustCrossEntropyLoss) – Cross entropy loss module.
dc (SoftDiceLoss) – Soft Dice loss module.
weight_ce (float) – Weight for the Cross Entropy loss.
weight_dice (float) – Weight for the Dice loss.
log_dice (bool) – Whether to log-transform the Dice loss.
ignore_label (int | None) – Label to ignore during loss calculation.
aggregate (str) – Method to aggregate losses (currently supports “sum” only).
name (str) – Name of the metric.
IoU
Like Dice, it compares two images and returns a number between 0 and 1; the closer it is to 1, the more similar the images are.
It is closely related to Dice but gives less weight to the intersection, so for the same prediction IoU is never higher than Dice (e.g. an IoU of 0.5 corresponds to a Dice of about 0.67).
- class biom3d.metrics.IoU(*args: Any, **kwargs: Any)
Intersection over Union (IoU) score computation.
This metric measures the overlap between the predicted and ground truth segmentation masks. Can be used as a metric or a loss function (when inverted during training). Supports both binary and multi-class segmentation tasks, with optional softmax activation and background removal.
- Variables:
name (str) – Name of the metric.
use_softmax (bool) – Whether to apply softmax (multi-class case).
dim (tuple) – Dimensions along which the IoU is computed.
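IoU is |P ∩ G| / |P ∪ G|. For the same pair of binary masks it is linked to Dice by Dice = 2·IoU / (1 + IoU), which is why IoU is never higher than Dice. The small NumPy check below verifies that relation; it uses no smoothing term, so the two masks must not both be empty.

```python
import numpy as np


def iou_and_dice(pred, target):
    """Return (IoU, Dice) for two binary arrays that are not both empty."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    iou = inter / union
    dice = 2 * inter / (pred.sum() + target.sum())
    return iou, dice


pred = np.array([1, 1, 1, 0, 0])
target = np.array([0, 1, 1, 1, 0])
iou, dice = iou_and_dice(pred, target)  # IoU = 2/4, Dice = 4/6
assert np.isclose(dice, 2 * iou / (1 + iou))
print(round(float(iou), 3), round(float(dice), 3))  # 0.5 0.667
```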
MSE
Uses the mean squared error to compute the loss.
- class biom3d.metrics.MSE(*args: Any, **kwargs: Any)
Mean Squared Error (MSE) loss.
This metric computes the average squared difference between predictions and targets. Commonly used as a regression loss, but can also serve as a metric in training.
- Variables:
name (str) – Name of the metric (for logging or display).
DeepMetric
A deep supervision metric. It can be used with DiceBCE and MSE (registered as DeepDiceBCE and DeepMSE).
- class biom3d.metrics.DeepMetric(*args: Any, **kwargs: Any)
Metric wrapper for deep supervision.
Applies a given metric across multiple feature maps (from different decoder depths, e.g., in U-Net) and combines the results as a sum weighted by the coefficients alphas. Each feature map prediction is compared to the same ground truth.
- Variables:
metric (Metric) – Base metric applied at each level.
alphas (list[float]) – Weights associated with each level’s output.
name (str) – Name of the metric.
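Deep supervision reduces one metric value per decoder output to a single score. The pure-Python sketch below shows that weighted sum. deep_metric, the toy list-based mse, and the alpha values are illustrative assumptions; the sketch also assumes every output already has the same shape as the target, and does not show how Biom3d brings deeper outputs to that shape.

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def deep_metric(metric, outputs, target, alphas):
    """Weighted sum of `metric` applied to every decoder output.

    outputs: one prediction per decoder level, each assumed to have the same
    shape as `target`; alphas: one weight per output.
    """
    if len(outputs) != len(alphas):
        raise ValueError("need exactly one alpha per output")
    return sum(a * metric(out, target) for a, out in zip(alphas, outputs))


target = [1.0, 0.0, 1.0, 0.0]
outputs = [[0.5, 0.5, 0.5, 0.5], [1.0, 0.0, 1.0, 1.0]]
alphas = [0.25, 1.0]
print(deep_metric(mse, outputs, target, alphas))  # 0.25*0.25 + 1.0*0.25 = 0.3125
```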
Trainers
Here are the available trainers:
SegTrain
Default trainer. It performs the full training step: forward pass, loss computation and parameter update.
- biom3d.trainers.seg_train(dataloader: torch.utils.data.dataloader.DataLoader, scaler: torch.amp.grad_scaler.GradScaler, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], optimizer: torch.optim.Optimizer, callbacks: Callbacks, epoch: int | None = None, use_deep_supervision: bool = False) → None
Train a segmentation model.
Calls the dataloader to get a batch of images and masks, passes them through the model, computes the loss from the model output and the masks, and updates the model parameters.
Works with CUDA, Metal, or CPU (CPU is much slower).
Works with half precision (fp16, CUDA only) and standard precision (fp32).
Uses gradient clipping during backpropagation.
- Parameters:
dataloader (DataLoader) – DataLoader for training data. A Dataloader is a Python class with an overloaded __getitem__ method. In this case, __getitem__ should return a batch of images and a batch of masks.
scaler (torch.amp.GradScaler) – Gradient scaler used for half precision.
model (torch.nn.Module) – The model to train.
loss_fn (biom3d.metrics.Metric) – The loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during training.
optimizer (torch.optim.Optimizer) – The optimizer used for training.
callbacks (biom3d.callbacks.Callbacks) – Callbacks to be called during training.
epoch (int, optional) – Current epoch number, required for deep supervision.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during training.
- Return type:
None
SegVal
Default validator.
- biom3d.trainers.seg_validate(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], use_fp16: bool, use_deep_supervision: bool = False) → None
Validate a segmentation model.
Calls the validation dataloader to get a batch of images and masks, passes them through the model, and computes the loss from the model output and the masks.
Works with CUDA, Metal, or CPU (CPU is much slower).
Works with half precision (fp16, CUDA only) and standard precision (fp32).
- Parameters:
dataloader (DataLoader) – DataLoader for validation data. A Dataloader is a Python class with an overloaded __getitem__ method. In this case, __getitem__ should return a batch of images and a batch of masks.
model (torch.nn.Module) – The model to validate.
loss_fn (biom3d.metrics.Metric) – The validation loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during validation.
use_fp16 (bool) – Flag to indicate if half-precision (fp16) is used.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during validation.
- Return type:
None
SegPatchTrain
TorchIO trainer, designed for TorchIO datasets and the patch-based approach.
- biom3d.trainers.seg_patch_train(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], optimizer: torch.optim.Optimizer, callbacks: Callbacks, epoch: int | None = None, use_deep_supervision: bool = False, **kwargs: dict[str, Any]) → None
Train the segmentation model using a TorchIO patch-based approach.
- Parameters:
dataloader (TorchIO DataLoader) – TorchIO DataLoader (such as generated using biom3d.datasets.semseg_torchio) containing training data in patches. A Dataloader is a Python class with an overloaded __getitem__ method. In this case, __getitem__ should return a batch of images and a batch of masks.
model (torch.nn.Module) – The model to train.
loss_fn (biom3d.metrics.Metric) – The loss function.
metrics (list of metrics) – List of metrics to calculate during patch-based training.
optimizer (torch.optim.Optimizer) – The optimizer used for training.
callbacks (biom3d.callbacks.Callbacks) – Callbacks to be called during training.
epoch (int, optional) – Current epoch number, required for deep supervision.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during training.
**kwargs (dict from str to any) – Just for compatibility
- Return type:
None
SegPatchVal
TorchIO validator, designed for TorchIO datasets and the patch-based approach.
- biom3d.trainers.seg_patch_validate(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], **kwargs: dict[str, Any]) → None
Validate the segmentation model with TorchIO patch-based approach.
- Parameters:
dataloader (TorchIO DataLoader) – TorchIO DataLoader (such as generated using biom3d.datasets.semseg_torchio) containing validation data in patches. A Dataloader is a Python class with an overloaded __getitem__ method. In this case, __getitem__ should return a batch of images and a batch of masks.
model (torch.nn.Module) – The model to validate.
loss_fn (biom3d.metrics.Metric) – The validation loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during validation.
**kwargs (dict from str to any) – Just for compatibility.
- Return type:
None
Preprocessor
There is currently only one preprocessor:
- biom3d.preprocess.seg_preprocessor(img: ndarray, img_meta: dict[str, Any], num_classes: int, msk: ndarray = None, use_one_hot: bool = False, remove_bg: bool = False, median_spacing: list[float] | ndarray = [], clipping_bounds: list[float] | tuple[float, float] = [], intensity_moments: list[float] | tuple[float, float] = [], channel_axis: int = 0, num_channels: int = 1, seed: int = 42, is_2d: bool = False) → tuple[ndarray, ndarray, dict[int, list[int]]] | tuple[ndarray, dict[str, Any]]
Perform a full preprocessing pipeline for segmentation images and masks.
This function orchestrates a series of steps:
Standardizes image and mask dimensions.
Validates and corrects the mask using robust heuristics.
Optionally one-hot encodes the mask.
Applies intensity transformations (clipping, normalization).
Resamples the data to a target spacing.
Computes foreground coordinates for patch sampling.
- Parameters:
img (numpy.ndarray) – The input image array. Can be 2D or 3D, with or without channel dimension.
img_meta (dict of str to any) – Dictionary containing image metadata, including the spacing field.
num_classes (int) – Number of segmentation classes. Required if msk is provided.
msk (numpy.ndarray, optional) – Segmentation mask corresponding to the image. Can be 2D or 3D.
use_one_hot (bool, default=False) – If True, the mask will be converted to one-hot encoding.
remove_bg (bool, default=False) – If True and use_one_hot is True, the background channel (0) is removed.
median_spacing (list or numpy.ndarray of float, optional) – Target spacing for resampling. If empty, resampling is skipped.
clipping_bounds (list or tuple of float, optional) – Tuple (min, max) to clip intensity values. If empty, no clipping is applied.
intensity_moments (list or tuple of float, optional) – Tuple (mean, std) for intensity normalization. If empty, stats are computed from image.
channel_axis (int, default=0) – Index of the channel axis in the input image.
num_channels (int, default=1) – Expected number of image channels after standardization.
seed (int, default=42) – Random seed for reproducibility in foreground sampling.
is_2d (bool, default=False) – If True, assumes the image and mask are 2D rather than 3D.
- Raises:
RuntimeError – If the mask format is invalid and cannot be corrected.
ValueError – If input dimensions are inconsistent with expected format.
- Returns:
If msk is provided, returns (img, msk, fg) –
- img: numpy.ndarray
Preprocessed image.
- msk: numpy.ndarray
Preprocessed segmentation mask.
- fg: dict
Mapping from class index to an array of sampled voxel coordinates.
If msk is None, returns (img, img_meta) –
- img: numpy.ndarray
Preprocessed image.
- img_meta: dict
Original metadata, with an added original_shape field.
Notes
Foreground sampling is capped at 10,000 voxels per class.
Designed for biological and medical image segmentation pipelines.
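Several of these steps are simple array operations. The NumPy sketch below illustrates three of them under stated assumptions: clipping to clipping_bounds followed by z-score normalization with intensity_moments (mean, std); the target shape implied by resampling from the image spacing to median_spacing (the interpolation itself is not shown); and sampling at most 10,000 foreground voxel coordinates per class with a seeded generator. All helper names are hypothetical, and skipping background (class 0) during foreground sampling is an assumption.

```python
import numpy as np


def normalize_intensity(img, clipping_bounds, intensity_moments):
    """Clip to (min, max), then z-score normalize with (mean, std)."""
    lo, hi = clipping_bounds
    mean, std = intensity_moments
    return (np.clip(img, lo, hi) - mean) / std


def resampled_shape(shape, spacing, target_spacing):
    """Spatial shape after resampling from `spacing` to `target_spacing`."""
    ratio = np.asarray(spacing, dtype=float) / np.asarray(target_spacing, dtype=float)
    return tuple(int(s) for s in np.round(np.asarray(shape) * ratio))


def sample_foreground(msk, max_voxels=10_000, seed=42):
    """Map each non-background class to at most `max_voxels` voxel coordinates."""
    rng = np.random.default_rng(seed)
    fg = {}
    for c in np.unique(msk):
        if c == 0:  # assumption: background is not sampled
            continue
        coords = np.argwhere(msk == c)
        if len(coords) > max_voxels:
            coords = coords[rng.choice(len(coords), size=max_voxels, replace=False)]
        fg[int(c)] = coords
    return fg


img = np.array([[-50.0, 0.0], [100.0, 500.0]])
print(normalize_intensity(img, clipping_bounds=(0.0, 200.0), intensity_moments=(100.0, 50.0)))
print(resampled_shape((40, 256, 256), spacing=(2.0, 0.5, 0.5), target_spacing=(1.0, 1.0, 1.0)))  # (80, 128, 128)
```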
Predictors
Here are the available predictors:
Seg
Note
This predictor should not be used; it is only kept for backward compatibility.
- biom3d.predictors.seg_predict(img_path: str, model: Module, return_logit: bool = False) → ndarray
Run a prediction on given image.
Segmentation Predictor V1
Load an image from a given path, run model prediction, and return either the binarized segmentation mask or raw logits.
- Parameters:
img_path (str) – Path to the image file to predict.
model (torch.nn.Module) – The PyTorch segmentation model.
return_logit (bool, default=False) – If True, returns the raw model logits without post-processing.
- Returns:
Segmentation mask (binary values 0 or 255) or raw logits.
- Return type:
numpy.ndarray
SegPatch
Note
This is the default predictor.
- biom3d.predictors.seg_predict_patch_2(img: ndarray, original_shape: tuple[int], model: Module, patch_size: tuple[int], conserve_size: bool = False, tta: bool = False, num_workers: int = 4, enable_autocast: bool = True, use_softmax: bool = True, keep_biggest_only: bool = False, **kwargs) → ndarray
For one image, computes the model prediction and returns the predicted logit.
Segmentation Predictor V2
Images are expected to be already preprocessed, which can be done with biom3d.preprocess.seg_preprocessor.
- Parameters:
img (numpy.ndarray) – The preprocessed image to predict.
original_shape (tuple of int) – Original shape of the image.
model (torch.nn.Module) – The segmentation model.
patch_size (tuple of int) – Size of the patch used during training.
conserve_size (bool, default=False) – Forces the logit to have the same size as the input. Useful if no post-processing is intended.
tta (bool, default=False) – Test time augmentation.
num_workers (int, default=4) – Number of workers.
enable_autocast (bool, default=True) – Whether to use half-precision.
use_softmax (bool, default=True) – [DEPRECATED!] Whether softmax activation has been used for training.
keep_biggest_only (bool, default=False) – [DEPRECATED!] When True, keeps only the biggest object in the output image.
**kwargs (dict from str to any) – Just here for compatibility.
- Returns:
The predicted segmentation mask or logit.
- Return type:
numpy.ndarray
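seg_predict_patch_2 predicts on patches of patch_size, so the patches have to cover the whole image. A common way to do this is a sliding window with overlapping patches whose logits are then combined. The sketch below only computes evenly spaced patch start positions along one axis, assuming a minimum overlap of 50%; patch_starts is a hypothetical helper, and Biom3d's actual patch grid and blending may differ.

```python
import math


def patch_starts(image_len, patch_len, overlap=0.5):
    """Evenly spaced start indices so that patches of patch_len cover image_len.

    Consecutive patches overlap by at least `overlap` * patch_len voxels;
    the first patch starts at 0 and the last one ends exactly at image_len.
    """
    if image_len <= patch_len:
        return [0]  # a single patch (the image would be padded to patch_len)
    max_step = patch_len * (1 - overlap)
    n = math.ceil((image_len - patch_len) / max_step) + 1
    step = (image_len - patch_len) / (n - 1)
    return [round(i * step) for i in range(n)]


# Start indices along each axis of a 100 x 200 x 200 image with 64^3 patches.
print([patch_starts(length, 64) for length in (100, 200, 200)])
```

Spreading the patches evenly, instead of using a fixed step and a shifted last patch, keeps the overlap roughly uniform across the image.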
Postprocessor
There is currently only one postprocessor:
- biom3d.predictors.seg_postprocessing(logit: ndarray | Tensor, original_shape: tuple[int], use_softmax: bool = True, force_softmax: bool = False, keep_big_only: bool = False, keep_biggest_only: bool = False, return_logit: bool = False, is_2d: bool = False, **kwargs)
Post-process the logit (model output) to obtain the final segmentation mask. Can optionally remove some noise.
Recommended to be used after biom3d.predictors.seg_predict_patch_2.
- Parameters:
logit (torch.Tensor or numpy.ndarray) – The raw model output.
original_shape (tuple of int) – Shape to resize the output to.
use_softmax (bool, default=True) – Whether softmax was used for training.
force_softmax (bool, default=False) – Whether sigmoid was used for training and the output should nevertheless be converted to a softmax-like output.
keep_big_only (bool, default=False) – Whether to keep the big objects only in the output. An Otsu threshold is used on the object volume distribution.
keep_biggest_only (bool, default=False) – When true keeps the biggest centered object only in the output.
return_logit (bool, default=False) – Whether to return the logit. Resampling is applied beforehand.
is_2d (bool, default=False) – Whether the image is 2D; only affects resizing.
**kwargs (dict from str to any) – Just here for compatibility.
- Raises:
AssertionError – If logit is not a numpy.ndarray or torch.Tensor.
- Returns:
The post-processed segmentation mask or logit.
- Return type:
numpy.ndarray
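How a logit becomes a mask depends on use_softmax: with softmax training the label is the argmax over channels, while with sigmoid training each channel is thresholded independently. The NumPy sketch below shows only that step, assuming the channel axis comes first and a sigmoid threshold of 0.5. Resizing to original_shape and the keep_big_only / keep_biggest_only filters are not reproduced, and logit_to_mask is a hypothetical name.

```python
import numpy as np


def logit_to_mask(logit, use_softmax=True):
    """Convert a (C, *spatial) logit array into a segmentation mask."""
    if use_softmax:
        # Softmax is monotonic, so the argmax of the logits equals the argmax of
        # the probabilities: one label map with values in [0, C).
        return np.argmax(logit, axis=0).astype(np.uint8)
    # Sigmoid: one binary map per channel; sigmoid(x) > 0.5 exactly when x > 0.
    return (logit > 0).astype(np.uint8)


logit = np.array([[[2.0, -1.0]], [[0.5, 3.0]]])  # shape (2 classes, 1, 2)
print(logit_to_mask(logit))                      # [[0 1]]
print(logit_to_mask(logit, use_softmax=False))
```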