Trainers
The trainers are Python functions that take a dataloader, a model, a loss function and an optimizer as input and run a training process. Optionally, a list of biom3d.metrics.Metric objects and a biom3d.callbacks.Callbacks object can be provided to the trainer to enrich the training loop.
Validaters, which optionally perform validation at the end of each epoch, are also defined in biom3d.trainers.
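This is a minimal, framework-free sketch of how a trainer and a validater could fit together. It runs an epoch loop that calls the trainer once per epoch and, if one is provided, the validater at the end of that epoch. Everything in it is hypothetical. `run_epochs`, `RecordingCallbacks` and its `on_epoch_begin`/`on_epoch_end` hooks are illustrative stand-ins. They do not reproduce the biom3d API or its actual call order.

```python
from typing import Any, Callable, Optional


class RecordingCallbacks:
    """Hypothetical callbacks object that records which hooks were fired."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def on_epoch_begin(self, epoch: int) -> None:
        self.events.append(f"begin {epoch}")

    def on_epoch_end(self, epoch: int) -> None:
        self.events.append(f"end {epoch}")


def run_epochs(
    n_epochs: int,
    train_fn: Callable[..., None],
    callbacks: RecordingCallbacks,
    val_fn: Optional[Callable[[], None]] = None,
    **train_kwargs: Any,
) -> None:
    """Call the trainer once per epoch and, if given, the validater after it."""
    for epoch in range(n_epochs):
        callbacks.on_epoch_begin(epoch)
        # The trainer receives the epoch number (the docs note it is needed for deep supervision).
        train_fn(epoch=epoch, callbacks=callbacks, **train_kwargs)
        if val_fn is not None:
            # Optional validation at the end of the epoch.
            val_fn()
        callbacks.on_epoch_end(epoch)


# Dummy trainer that only logs its call; a real one would also get dataloader, model, etc.
def dummy_train(epoch: int, callbacks: RecordingCallbacks, **kwargs: Any) -> None:
    callbacks.events.append(f"train {epoch}")


cbs = RecordingCallbacks()
run_epochs(2, dummy_train, cbs, val_fn=lambda: cbs.events.append("validate"))
print(cbs.events)
# ['begin 0', 'train 0', 'validate', 'end 0', 'begin 1', 'train 1', 'validate', 'end 1']
```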
- biom3d.trainers.seg_patch_train(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], optimizer: torch.optim.Optimizer, callbacks: Callbacks, epoch: int | None = None, use_deep_supervision: bool = False, **kwargs: dict[str, Any]) -> None
Train the segmentation model using a TorchIO patch-based approach.
- Parameters:
dataloader (TorchIO DataLoader) – TorchIO DataLoader (such as one generated using biom3d.datasets.semseg_torchio) containing training data in patches. A DataLoader is an iterable over a dataset (whose __getitem__ method returns individual samples). In this case, each iteration should yield a batch of images and a batch of masks.
model (torch.nn.Module) – The model to train.
loss_fn (biom3d.metrics.Metric) – The loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during patch-based training.
optimizer (torch.optim.Optimizer) – The optimizer used for training.
callbacks (biom3d.callbacks.Callbacks) – Callbacks to be called during training.
epoch (int, optional) – Current epoch number, required for deep supervision.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during training.
**kwargs (dict from str to any) – Accepted only for compatibility.
- Return type:
None
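The loop below sketches a typical patch-based training epoch. It assumes TorchIO-style batches: dicts that map each image name to a dict whose `"data"` entry (the value of `torchio.DATA`) holds the tensor. Several parts are assumptions for illustration. These are the key names `img` and `msk`, the helper `patch_train_epoch`, and the use of `nn.BCEWithLogitsLoss` in place of a `biom3d.metrics.Metric`. Deep supervision, metrics and callbacks are left out.

```python
import torch
from torch import nn

DATA = "data"  # value of torchio.DATA, the key holding the tensor in a TorchIO batch


def patch_train_epoch(loader, model, loss_fn, optimizer, img_key="img", msk_key="msk"):
    """Minimal patch-based training loop; returns the mean loss over batches."""
    model.train()
    total = 0.0
    for batch in loader:
        # Assumed layout: {image_name: {"data": tensor of shape (B, C, W, H, D)}, ...}
        x = batch[img_key][DATA].float()
        y = batch[msk_key][DATA].float()
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


torch.manual_seed(0)
# Two fake batches, each with 2 single-channel 8x8x8 patches and binary masks.
loader = [
    {"img": {DATA: torch.randn(2, 1, 8, 8, 8)},
     "msk": {DATA: (torch.rand(2, 1, 8, 8, 8) > 0.5).float()}}
    for _ in range(2)
]
model = nn.Conv3d(1, 1, kernel_size=3, padding=1)  # toy stand-in for a segmentation net
opt = torch.optim.SGD(model.parameters(), lr=0.1)
mean_loss = patch_train_epoch(loader, model, nn.BCEWithLogitsLoss(), opt)
print(f"mean loss: {mean_loss:.3f}")
```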
- biom3d.trainers.seg_patch_validate(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], **kwargs: dict[str, Any]) -> None
Validate the segmentation model with a TorchIO patch-based approach.
- Parameters:
dataloader (TorchIO DataLoader) – TorchIO DataLoader (such as one generated using biom3d.datasets.semseg_torchio) containing validation data in patches. A DataLoader is an iterable over a dataset (whose __getitem__ method returns individual samples). In this case, each iteration should yield a batch of images and a batch of masks.
model (torch.nn.Module) – The model to validate.
loss_fn (biom3d.metrics.Metric) – The validation loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during validation.
**kwargs (dict from str to any) – Accepted only for compatibility.
- Return type:
None
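Patch-based validation differs from training mainly in two ways: the model is put in evaluation mode, and gradients are disabled. The sketch below assumes the same TorchIO-style batch layout (a `"data"` key inside each image entry) and hypothetical `img`/`msk` key names. `patch_validate` and the soft `dice` metric are illustrative helpers, not the biom3d implementations.

```python
import torch
from torch import nn

DATA = "data"  # value of torchio.DATA


def dice(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical soft Dice score on sigmoid probabilities (stand-in metric)."""
    probs = torch.sigmoid(logits)
    inter = (probs * target).sum()
    return (2 * inter + eps) / (probs.sum() + target.sum() + eps)


@torch.no_grad()  # no gradients are needed during validation
def patch_validate(loader, model, loss_fn, metric_fns, img_key="img", msk_key="msk"):
    """Return the mean validation loss and the mean of each metric over all batches."""
    model.eval()  # e.g. switches dropout / batch norm to inference behaviour
    total_loss = 0.0
    totals = [0.0] * len(metric_fns)
    for batch in loader:
        x = batch[img_key][DATA].float()
        y = batch[msk_key][DATA].float()
        out = model(x)
        total_loss += loss_fn(out, y).item()
        for i, fn in enumerate(metric_fns):
            totals[i] += fn(out, y).item()
    n = len(loader)
    return total_loss / n, [t / n for t in totals]


torch.manual_seed(0)
loader = [{"img": {DATA: torch.randn(1, 1, 8, 8, 8)},
           "msk": {DATA: (torch.rand(1, 1, 8, 8, 8) > 0.5).float()}}]
model = nn.Conv3d(1, 1, kernel_size=3, padding=1)
val_loss, (val_dice,) = patch_validate(loader, model, nn.BCEWithLogitsLoss(), [dice])
print(f"loss {val_loss:.3f}, dice {val_dice:.3f}")
```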
- biom3d.trainers.seg_train(dataloader: torch.utils.data.dataloader.DataLoader, scaler: torch.amp.grad_scaler.GradScaler, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], optimizer: torch.optim.Optimizer, callbacks: Callbacks, epoch: int | None = None, use_deep_supervision: bool = False) -> None
Train a segmentation model.
Iterate over the dataloader to get batches of images and masks, pass each batch through the model, compute the loss from the model output and the masks, and update the model parameters.
Works on CUDA, Metal or CPU; CPU is much slower.
Supports half precision (fp16, CUDA only) and standard precision (fp32).
Uses gradient clipping during backpropagation.
- Parameters:
dataloader (DataLoader) – DataLoader for training data. A DataLoader is an iterable over a dataset (whose __getitem__ method returns individual samples). In this case, each iteration should yield a batch of images and a batch of masks.
scaler (torch.amp.GradScaler) – Gradient scaler used for half-precision (fp16) training.
model (torch.nn.Module) – The model to train.
loss_fn (biom3d.metrics.Metric) – The loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during training.
optimizer (torch.optim.Optimizer) – The optimizer used for training.
callbacks (biom3d.callbacks.Callbacks) – Callbacks to be called during training.
epoch (int, optional) – Current epoch number, required for deep supervision.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during training.
- Return type:
None
- biom3d.trainers.seg_validate(dataloader: torch.utils.data.dataloader.DataLoader, model: torch.nn.Module, loss_fn: Metric, metrics: list[Metric], use_fp16: bool, use_deep_supervision: bool = False) -> None
Validate a segmentation model.
Iterate over the validation dataloader to get batches of images and masks, pass each batch through the model, and compute the loss from the model output and the masks.
Works on CUDA, Metal or CPU; CPU is much slower.
Supports half precision (fp16, CUDA only) and standard precision (fp32).
- Parameters:
dataloader (DataLoader) – DataLoader for validation data. A DataLoader is an iterable over a dataset (whose __getitem__ method returns individual samples). In this case, each iteration should yield a batch of images and a batch of masks.
model (torch.nn.Module) – The model to validate.
loss_fn (biom3d.metrics.Metric) – The validation loss function.
metrics (list of biom3d.metrics.Metric) – List of metrics to compute during validation.
use_fp16 (bool) – Flag to indicate if half-precision (fp16) is used.
use_deep_supervision (bool, default=False) – If True, deep supervision is used during validation.
- Return type:
None
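The sketch below covers this validater's two options: fp16 autocasting, applied only to CUDA tensors, and deep supervision. It rests on a hypothetical convention for deep supervision. The model is assumed to return a list of outputs with the full-resolution prediction first, and validation is assumed to score only that first output. `ToyDeepSupNet` and `validate` are illustrative names, and biom3d may handle multi-scale outputs differently.

```python
import contextlib

import torch
from torch import nn


class ToyDeepSupNet(nn.Module):
    """Hypothetical model returning [full-res, half-res] logits, as deep supervision nets often do."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x):
        full = self.head(x)
        return [full, nn.functional.avg_pool3d(full, 2)]


@torch.no_grad()
def validate(loader, model, loss_fn, use_fp16, use_deep_supervision=False):
    """Return the mean validation loss over all (images, masks) batches."""
    model.eval()
    total = 0.0
    for x, y in loader:
        # fp16 autocast only on CUDA, mirroring the "fp16, CUDA only" note above.
        ctx = (torch.autocast("cuda", dtype=torch.float16)
               if use_fp16 and x.is_cuda else contextlib.nullcontext())
        with ctx:
            out = model(x)
            if use_deep_supervision:
                out = out[0]  # assumption: keep only the full-resolution output
            total += loss_fn(out, y).item()
    return total / len(loader)


torch.manual_seed(0)
loader = [(torch.randn(1, 1, 8, 8, 8), torch.randint(0, 2, (1, 8, 8, 8)))]
val_loss = validate(loader, ToyDeepSupNet(), nn.CrossEntropyLoss(),
                    use_fp16=False, use_deep_supervision=True)
print(f"validation loss: {val_loss:.3f}")
```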