Source code for biom3d.trainers

"""
The trainers are Python functions that take as input a dataloader, a model, a loss function and an optimizer to run a training process. Optionally, a list of biom3d.metrics.Metric and a biom3d.callbacks.Callbacks can be provided to the trainer to enrich the training loop.

Validators, which optionally perform validation at the end of each epoch, are also defined in biom3d.trainers.
"""
# TODO: maybe re-structure with classes? It doesn't feel necessary for the moment, but it could help define an interface to ease extension.

import torch 
from tqdm import tqdm
from time import time 
from contextlib import nullcontext

from typing import Any
from torch.utils.data.dataloader import DataLoader
from torch.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from torch.nn import Module
from biom3d.metrics import Metric
from biom3d.callbacks import Callbacks
#---------------------------------------------------------------------------
# model trainers for segmentation

def seg_train(
    dataloader: DataLoader,
    scaler: GradScaler,
    model: Module,
    loss_fn: Metric,
    metrics: list[Metric],
    optimizer: Optimizer,
    callbacks: Callbacks,
    epoch: int | None = None,  # required by deep supervision
    use_deep_supervision: bool = False,
) -> None:
    """
    Train a segmentation model for one pass over the dataloader.

    Call the dataloader to get a batch of images and masks, pass the images through
    the model, compute the loss using the model output and the masks, and update the
    model parameters.

    Works with CUDA, Metal (MPS) or CPU. CPU is much slower.
    Works with half precision (fp16, CUDA only) and with standard precision (fp32).
    Uses gradient clipping (max norm 12) during backpropagation.

    Parameters
    ----------
    dataloader : DataLoader
        DataLoader for training data. A Dataloader is a Python class with an overloaded
        `__getitem__` method. In this case, `__getitem__` should return a batch of
        images and a batch of masks.
    scaler : torch.amp.GradScaler or None
        Gradient scaler for half precision. When not None (and CUDA is available) the
        forward pass runs under CUDA autocast and the loss is scaled before backward.
        When None, standard precision is used.
    model : torch.nn.Module
        The model to train.
    loss_fn : biom3d.metrics.Metric
        The loss function. It always receives the full model output.
    metrics : list of biom3d.metrics.Metric
        List of metrics to compute during training (computed without gradients).
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    callbacks : biom3d.callbacks.Callbacks
        Callbacks to be called at the beginning and end of each batch.
    epoch : int, optional
        Current epoch number, documented as required for deep supervision. It is not
        read by this function.
    use_deep_supervision : bool, default=False
        If True, the model output is indexed with `[-1]` before being passed to the
        metrics (the loss still gets the full output).

    Returns
    -------
    None
    """

    def _synchronize() -> None:
        # Wait for queued accelerator work so that the wall-clock timings are meaningful.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    model.train()

    # Autocast is only used for CUDA half precision; the decision does not change
    # during the epoch.
    use_autocast = scaler is not None and torch.cuda.is_available()

    _synchronize()
    # Reset at the end of every batch: the measured duration is the time spent
    # waiting for the next batch to be loaded and moved to the device.
    t_batch_start = time()
    print("[time] start epoch")

    for batch, (X, y) in enumerate(dataloader):
        callbacks.on_batch_begin(batch)

        # Move the batch to the accelerator (same if/elif dispatch as the validators).
        if torch.cuda.is_available():
            X, y = X.cuda(), y.cuda()
        elif torch.backends.mps.is_available():
            X, y = X.to('mps'), y.to('mps')
        _synchronize()

        batch_duration = time() - t_batch_start
        if batch_duration > 1:
            print(f"[Warning] Batch {batch} took {batch_duration:.2f}s — possible slowdown.")

        # Compute prediction error
        with torch.amp.autocast("cuda") if use_autocast else nullcontext():
            pred = model(X)
            del X  # release the input as early as possible
            loss = loss_fn(pred, y)
            with torch.no_grad():
                # With deep supervision, metrics are computed on pred[-1] only.
                final_pred = pred[-1] if use_deep_supervision else pred
                for m in metrics:
                    m(final_pred, y)

        # Backpropagation. PyTorch accumulates gradients across backward() calls,
        # so they must be reset before computing the gradients of this batch.
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()  # compute (scaled) gradients
            # Unscale before clipping so the max norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
            scaler.step(optimizer)  # apply gradients
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
            optimizer.step()

        del loss, pred, final_pred, y
        callbacks.on_batch_end(batch)

        _synchronize()
        t_batch_start = time()

    # Release cached allocator memory at the end of the epoch.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
def seg_validate(
    dataloader: DataLoader,
    model: Module,
    loss_fn: Metric,
    metrics: list[Metric],
    use_fp16: bool,
    use_deep_supervision: bool = False,
) -> None:
    """
    Validate a segmentation model.

    Call the validation dataloader to get a batch of images and masks, pass the
    images through the model, and compute the loss and metrics using the model
    output and the masks. The averaged loss and the metrics are printed at the end.

    Works with CUDA, Metal (MPS) or CPU. CPU is much slower.
    Works with half precision (fp16, CUDA only) and with standard precision (fp32).

    Parameters
    ----------
    dataloader : DataLoader
        DataLoader for validation data. A Dataloader is a Python class with an
        overloaded `__getitem__` method. In this case, `__getitem__` should return a
        batch of images and a batch of masks.
    model : torch.nn.Module
        The model to validate.
    loss_fn : biom3d.metrics.Metric
        The validation loss function. It always receives the full model output.
    metrics : list of biom3d.metrics.Metric
        List of metrics to compute during validation.
    use_fp16 : bool
        If True and CUDA is available, the forward pass runs under CUDA autocast.
    use_deep_supervision : bool, default=False
        If True, the model output is indexed with `[-1]` before being passed to the
        metrics.

    Returns
    -------
    None
    """
    for m in [loss_fn] + metrics:
        m.reset()  # reset metrics

    # Set the module in evaluation mode (only useful for dropout or batchnorm like layers).
    model.eval()

    # Call is_available(): the bare function object is always truthy, which would
    # enable CUDA autocast on CPU/MPS machines.
    use_autocast = use_fp16 and torch.cuda.is_available()

    with torch.no_grad():  # no gradient tracking during validation
        for X, y in dataloader:
            if torch.cuda.is_available():
                X, y = X.cuda(), y.cuda()
            elif torch.backends.mps.is_available():
                X, y = X.to('mps'), y.to('mps')

            with torch.amp.autocast("cuda") if use_autocast else nullcontext():
                pred = model(X)
            del X

            loss_fn(pred, y)
            loss_fn.update()

            # With deep supervision, metrics are computed on pred[-1] only.
            final_pred = pred[-1] if use_deep_supervision else pred
            for m in metrics:
                m(final_pred, y)
                m.update()
            del pred, final_pred, y

    # Release cached allocator memory once validation is done.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()

    template = "val error: avg loss {:.3f}".format(loss_fn.avg.item())
    for m in metrics:
        template += ", " + str(m)
    print(template)
#--------------------------------------------------------------------------- # model trainers for segmentation with patches
def seg_patch_validate(
    dataloader: DataLoader,
    model: Module,
    loss_fn: Metric,
    metrics: list[Metric],
    **kwargs: dict[str, Any],
) -> None:
    """
    Validate the segmentation model with a TorchIO patch-based approach.

    Each item yielded by `dataloader` is wrapped in its own DataLoader (same batch
    size, no worker processes) and iterated patch batch by patch batch. The loss and
    every metric are reset first, updated for every patch batch, and the summary is
    printed at the end.

    Parameters
    ----------
    dataloader : TorchIO DataLoader
        TorchIO DataLoader (such as generated using biom3d.datasets.semseg_torchio)
        containing validation data in patches. Iterating over it should yield items
        that, once wrapped in a DataLoader, produce batches of images and masks.
    model : torch.nn.Module
        The model to validate.
    loss_fn : biom3d.metrics.Metric
        The validation loss function.
    metrics : list of biom3d.metrics.Metric
        List of metrics to compute during validation.
    **kwargs : dict from str to any
        Ignored; accepted for signature compatibility with the other validators.

    Returns
    -------
    None
    """
    print("Start validation...")

    # Clear the running statistics of the loss and of every metric.
    for tracker in [loss_fn] + metrics:
        tracker.reset()

    # Evaluation mode (only matters for dropout / batchnorm-like layers).
    model.eval()

    def _to_device(images, masks):
        # Same device priority as the rest of the module: CUDA, then MPS, then CPU.
        if torch.cuda.is_available():
            return images.cuda(), masks.cuda()
        if torch.backends.mps.is_available():
            return images.to('mps'), masks.to('mps')
        return images, masks

    with torch.no_grad():
        for item in tqdm(dataloader):
            patches = torch.utils.data.DataLoader(
                item,
                batch_size=dataloader.batch_size,
                num_workers=0,
            )
            for images, masks in patches:
                images, masks = _to_device(images, masks)
                output = model(images).detach()

                loss_fn(output, masks)
                loss_fn.update()

                for metric in metrics:
                    metric(output, masks)
                    metric.update()

    summary = ["val error: avg loss {:.3f}".format(loss_fn.avg.item())]
    summary.extend(str(metric) for metric in metrics)
    print(", ".join(summary))
def seg_patch_train(
    dataloader: DataLoader,
    model: Module,
    loss_fn: Metric,
    metrics: list[Metric],
    optimizer: Optimizer,
    callbacks: Callbacks,
    epoch: int | None = None,  # required by deep supervision
    use_deep_supervision: bool = False,
    **kwargs: dict[str, Any],
) -> None:
    """
    Train the segmentation model using a TorchIO patch-based approach.

    Each item (queue) yielded by `dataloader` is wrapped in its own DataLoader (same
    batch size, no shuffling, no worker processes, pinned memory), and the model is
    updated once per patch batch.

    Parameters
    ----------
    dataloader : TorchIO DataLoader
        TorchIO DataLoader (such as generated using biom3d.datasets.semseg_torchio)
        containing training data in patches. Iterating over it should yield items
        that, once wrapped in a DataLoader, produce batches of images and masks.
    model : torch.nn.Module
        The model to train.
    loss_fn : biom3d.metrics.Metric
        The loss function.
    metrics : list of biom3d.metrics.Metric
        List of metrics to calculate during patch-based training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    callbacks : biom3d.callbacks.Callbacks
        Callbacks called at the beginning and end of each patch batch, with the index
        `batch * len(patch_loader) + patch`.
    epoch : int, optional
        Current epoch number. With deep supervision it is forwarded to `loss_fn` and
        to every metric.
    use_deep_supervision : bool, default=False
        If True, `loss_fn` gets the full model output plus `epoch`, and metrics get
        `pred[-1]` plus `epoch`.
    **kwargs : dict from str to any
        Ignored; accepted for signature compatibility with the other trainers.

    Returns
    -------
    None
    """
    model.train()
    for batch, queue in enumerate(dataloader):
        patch_loader = torch.utils.data.DataLoader(
            queue,
            batch_size=dataloader.batch_size,
            drop_last=False,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        n_patch_batches = len(patch_loader)  # constant for this queue

        for patch, (X, y) in enumerate(patch_loader):
            step = batch * n_patch_batches + patch
            callbacks.on_batch_begin(step)

            if torch.cuda.is_available():
                X, y = X.cuda(), y.cuda()
            elif torch.backends.mps.is_available():
                X, y = X.to('mps'), y.to('mps')

            # Compute prediction error
            pred = model(X)
            if use_deep_supervision:
                loss = loss_fn(pred, y, epoch)
                for m in metrics:
                    m(pred[-1].detach(), y.detach(), epoch)
            else:
                loss = loss_fn(pred, y)
                for m in metrics:
                    m(pred.detach(), y.detach())

            # Backpropagation. PyTorch accumulates gradients across backward() calls,
            # so they must be reset before computing the gradients of this batch.
            optimizer.zero_grad()
            loss.backward()  # compute gradient
            optimizer.step()  # apply gradient

            # The former `loss.detach()` was a no-op (it returns a new tensor and does
            # not modify `loss`); dropping the references is what frees them early.
            del loss, pred
            callbacks.on_batch_end(step)
#---------------------------------------------------------------------------