"""
Callbacks are periodically called during training.
There are currently 3 different periods:
- training
- epoch
- batch
"""
import torch
import os
from shutil import copyfile
import matplotlib.pyplot as plt
plt.switch_backend('Agg') # bug fix: change matplotlib backend
from torch.utils.tensorboard import SummaryWriter
from contextlib import nullcontext
from abc import abstractmethod
import numpy as np
from typing import Any, Literal, Optional
from biom3d.metrics import Metric
#----------------------------------------------------------------------------
# Base classes
[docs]
class Callback:
    """
    Base class from which every callback is derived.

    A callback is invoked periodically while training runs. Three periods
    exist:

    - training
    - epoch
    - batch

    Each period has one hook called before it and one called after it. Any
    method whose name starts with `on_` may be overridden; for instance
    `on_epoch_end` is invoked once an epoch has finished.

    The hooks carry the `abstractmethod` marker, but the class does not use
    `ABCMeta`, so it can still be instantiated and subclasses may override
    only the hooks they need.

    :ivar list[biom3d.Metric] metrics: Metrics given through `set_trainer` (None until then).
    """

    metrics: list[Metric]

    def __init__(self):
        """Create the callback with no metrics attached yet."""
        self.metrics = None

    def set_trainer(self, metrics: list[Metric]) -> None:
        """
        Attach metrics or training state to this callback.

        Parameters
        ----------
        metrics : list of biom3d.Metric
            Metrics or training information the callback may use.

        Returns
        -------
        None
        """
        self.metrics = metrics

    @abstractmethod
    def on_batch_begin(self, batch: int) -> None:
        """
        Hook invoked before a batch is processed.

        Parameters
        ----------
        batch : int
            The batch index.

        Returns
        -------
        None
        """

    @abstractmethod
    def on_batch_end(self, batch: int) -> None:
        """
        Hook invoked after a batch has been processed.

        Parameters
        ----------
        batch : int
            The batch index.

        Returns
        -------
        None
        """

    @abstractmethod
    def on_epoch_begin(self, epoch: int) -> None:
        """
        Hook invoked before an epoch is processed.

        Parameters
        ----------
        epoch : int
            The epoch index.

        Returns
        -------
        None
        """

    @abstractmethod
    def on_epoch_end(self, epoch: int) -> None:
        """
        Hook invoked after an epoch has been processed.

        Parameters
        ----------
        epoch : int
            The epoch index.

        Returns
        -------
        None
        """

    @abstractmethod
    def on_train_begin(self, epoch: Optional[int] = None) -> None:
        """
        Hook invoked once, when training starts.

        Parameters
        ----------
        epoch : int, optional
            The epoch index.

        Returns
        -------
        None
        """

    @abstractmethod
    def on_train_end(self, epoch: Optional[int] = None) -> None:
        """
        Hook invoked once, when training is over.

        Parameters
        ----------
        epoch : int, optional
            The epoch index.

        Returns
        -------
        None
        """
[docs]
class Callbacks(Callback):
    """
    Container that forwards every callback event to a set of callbacks.

    The object holds several `Callback` instances in a dictionary and relays
    each event (`on_batch_begin`, `on_batch_end`, `on_epoch_begin`,
    `on_epoch_end`, `on_train_begin`, `on_train_end`) as well as
    `set_trainer` to all of them, in the iteration order of the dictionary.

    This makes it easy to combine logging, checkpoint saving, metric
    computation, etc. behind a single object.

    :ivar dict[str,Callback] callbacks: The managed callbacks, indexed by name.

    Examples
    --------
    >>> cbs = Callbacks({
    ...     "logger": LogSaver(...),
    ...     "saver": ModelSaver(...),
    ... })
    >>> cbs.on_epoch_end(10)  # calls on_epoch_end(10) on both LogSaver and ModelSaver
    """

    def __init__(self, callbacks: dict[str, Any]):
        """
        Build the container from a dictionary of callbacks.

        Parameters
        ----------
        callbacks : dict of str to biom3d.callback.Callback
            The callbacks to manage. Another `Callbacks` instance is also
            accepted (its dictionary is reused) and `None` gives an empty
            container.
        """
        super().__init__()
        if isinstance(callbacks, Callbacks):
            # unwrap: share the dictionary of the other container
            callbacks = callbacks.callbacks
        self.callbacks = {} if callbacks is None else callbacks

    def __getitem__(self, name: str) -> Callback:
        """
        Return the callback registered under `name`.

        Parameters
        ----------
        name : str
            Key of the callback to retrieve.

        Raises
        ------
        KeyError
            If no callback is registered under this name.

        Returns
        -------
        Callback
            The callback stored under `name`.
        """
        return self.callbacks[name]

    def set_trainer(self, trainer: list[Metric]) -> None:
        """
        Hand the trainer or metrics context to every callback.

        Parameters
        ----------
        trainer : list of biom3d.Metric
            Trainer or metrics list forwarded to each callback.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.set_trainer(trainer)

    def on_batch_begin(self, batch: int) -> None:
        """
        Relay the "batch begins" event to every callback.

        Parameters
        ----------
        batch : int
            Current batch index.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_batch_begin(batch)

    def on_batch_end(self, batch: int) -> None:
        """
        Relay the "batch ended" event to every callback.

        Parameters
        ----------
        batch : int
            Current batch index.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_batch_end(batch)

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Relay the "epoch begins" event to every callback.

        Parameters
        ----------
        epoch : int
            Current epoch index.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_epoch_begin(epoch)

    def on_epoch_end(self, epoch: int) -> None:
        """
        Relay the "epoch ended" event to every callback.

        Parameters
        ----------
        epoch : int
            Current epoch index.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_epoch_end(epoch)

    def on_train_begin(self, epoch: Optional[int] = None) -> None:
        """
        Relay the "training begins" event to every callback.

        Parameters
        ----------
        epoch : int, optional
            Starting epoch number, if any.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_train_begin(epoch)

    def on_train_end(self, epoch: Optional[int] = None) -> None:
        """
        Relay the "training ended" event to every callback.

        Parameters
        ----------
        epoch : int, optional
            Final epoch number, if any.

        Returns
        -------
        None
        """
        for cb in self.callbacks.values():
            cb.on_train_end(epoch)
#----------------------------------------------------------------------------
# Savers
[docs]
class ModelSaver(Callback):  # TODO: save best_only
    """
    Save the model, optimizer state, epoch and loss state during training.

    A checkpoint is written to `path_last` every `every_epoch` epochs. When
    `save_best` is enabled and a monitored `loss` is given, a second
    checkpoint is written to `path_best` each time the monitored loss improves.

    :ivar torch.nn.Module | list[torch.nn.Module] model: Model to store, if list it is assumed to be [student,teacher].
    :ivar torch.optim.Optimizer optimizer: torch optimizer.
    :ivar str path: Name of the file representing the model.
    :ivar str path_last: path + '.pth'
    :ivar str path_best: path + '_best.pth'
    :ivar int every_epoch: Period between save
    :ivar bool save_best: Whether to save best or not.
    :ivar float best_loss: Best loss value since beginning.
    :ivar biom3d.Metric loss: Monitored loss (its `avg` attribute is read), may be None.
    :ivar biom3d.Metric saved_loss: Loss whose `state_dict()` is stored alongside the model, may be None.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 path: str = 'unet',
                 every_epoch: int = 2,
                 save_best: bool = True,
                 loss: Optional[Metric] = None,  # loss used for saving the best model, generally the val loss
                 saved_loss: Optional[Metric] = None,  # loss being saved
                 ):
        """
        Initialize the saver.

        Parameters
        ----------
        model : torch.nn.Module or a list of torch.nn.Module
            A torch module to store. If the model is a list then it is considered as being [student,teacher]
        optimizer : torch.optim.Optimizer
            A torch optimizer.
        path : str
            Name of the model to store. The `.pth` extension is automatically added and the best model is stored with `_best.pth` extension.
        every_epoch : int, default=2
            Period to save the model.
        save_best : bool, default=True
            Whether to save the best model.
        loss : biom3d.Metric, optional
            Monitored loss used to decide when the best model is saved.
            When None, no best model is saved.
        saved_loss : biom3d.Metric, optional
            Loss saved alongside the model. When None, the checkpoint has no
            'loss' entry.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.path = path
        self.path_last = path + '.pth'
        self.path_best = path + '_best.pth'
        self.every_epoch = every_epoch
        self.save_best = save_best
        self.best_loss = float('inf')
        self.loss = loss
        self.saved_loss = saved_loss

    def on_train_begin(self, epoch: int) -> None:
        """
        Call once at the beginning of training.

        If model checkpoint files already exist, creates backups by copying them
        to new filenames suffixed with the current epoch number to prevent overwriting.

        Parameters
        ----------
        epoch : int
            The starting epoch number.

        Returns
        -------
        None
        """
        if os.path.exists(self.path_last):
            copyfile(self.path_last, self.path + "_" + str(epoch) + ".pth")
        if os.path.exists(self.path_best):
            copyfile(self.path_best, self.path + "_" + str(epoch) + "_best.pth")

    def _build_checkpoint(self, epoch: int) -> dict:
        """Collect the epoch, model, optimizer and (if any) loss states in a dictionary."""
        if isinstance(self.model, list):
            checkpoint = {
                'epoch': epoch,
                'student': self.model[0].state_dict(),
                'teacher': self.model[1].state_dict(),
                'opt': self.optimizer.state_dict(),
            }
        else:
            checkpoint = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'opt': self.optimizer.state_dict(),
            }
        # `saved_loss` defaults to None: only store its state when it exists
        if self.saved_loss is not None:
            checkpoint['loss'] = self.saved_loss.state_dict()
        return checkpoint

    def on_epoch_end(self, epoch: int) -> None:
        """
        Call at the end of each epoch.

        Saves the model state, optimizer state, epoch number, and loss state.
        Saves every `every_epoch` epochs, and optionally saves the best model
        if the monitored loss improves. A monitored average equal to 0 never
        triggers a "best" save.

        Parameters
        ----------
        epoch : int
            The current epoch number.

        Returns
        -------
        None
        """
        save_last = epoch % self.every_epoch == 0

        # the best model can only be tracked when a monitored loss was provided
        save_best = False
        monitored = None
        if self.save_best and self.loss is not None:
            monitored = self.loss.avg
            save_best = bool((monitored < self.best_loss) and (monitored != 0))

        if not (save_last or save_best):
            return  # nothing to write this epoch, skip building the state dicts

        save_dict = self._build_checkpoint(epoch)

        if save_last:
            torch.save(save_dict, self.path_last)
            print('Save model to {}'.format(self.path_last))

        # save best model if needed
        if save_best:
            save_dict['best_loss'] = monitored
            torch.save(save_dict, self.path_best)
            print('Save best model to {}'.format(self.path_best))
            self.best_loss = monitored
[docs]
class LogSaver(Callback):
    """
    Callback that logs training and validation metrics to a CSV file at the end of each epoch.

    This logger writes a `log.csv` file in the specified log directory and appends:

    - Epoch number
    - Learning rate (if a scheduler exposing `get_last_lr` is provided)
    - Training and validation loss
    - Training and validation metrics

    :ivar str path: Full path to the log CSV file.
    :ivar int crt_epoch: Current epoch (starts at 1).
    :ivar biom3d.Metric train_loss: Training loss object.
    :ivar biom3d.Metric val_loss: Validation loss object.
    :ivar list[biom3d.Metric] train_metrics: List of training metrics.
    :ivar list[biom3d.Metric] val_metrics: List of validation metrics.
    :ivar Optional[Callback] scheduler: Scheduler providing current learning rate.
    """

    def __init__(self,
                 log_dir: str,  # path to where the csv file will be stored
                 train_loss: Metric,
                 val_loss: Optional[Metric] = None,
                 train_metrics: Optional[list[Metric]] = None,
                 val_metrics: Optional[list[Metric]] = None,
                 scheduler: Optional[Callback] = None,
                 ):
        """
        Initialize the log saver.

        The CSV file is created if needed and the header line is written
        only when the file is empty, so an existing log is appended to.

        Parameters
        ----------
        log_dir : str
            Path to the directory where the CSV log file will be created.
        train_loss : biom3d.Metric
            Metric object tracking training loss.
        val_loss : biom3d.Metric, optional
            Metric object tracking validation loss.
        train_metrics : list of biom3d.Metric, optional
            List of training metrics to log.
        val_metrics : list of biom3d.Metric, optional
            List of validation metrics to log.
        scheduler : Callback, optional
            Learning rate scheduler for logging the current learning rate.
            Ignored when it has no `get_last_lr` method.
        """
        self.path = os.path.join(log_dir, 'log.csv')
        self.crt_epoch = 1
        self.train_loss = train_loss
        self.val_loss = val_loss
        self.train_metrics = train_metrics
        self.val_metrics = val_metrics
        self.scheduler = scheduler if hasattr(scheduler, 'get_last_lr') else None

        # a missing or empty file gets its header (opening in append mode creates it)
        if not os.path.exists(self.path) or os.path.getsize(self.path) == 0:
            self.write_file_head()

    @staticmethod
    def _as_scalar(value: Any) -> Any:
        """Return `value.item()` for tensor-like values, or `value` itself for plain numbers."""
        return value.item() if hasattr(value, 'item') else value

    def write_file_head(self) -> None:
        """
        Write the header of the CSV file with appropriate column names.

        Returns
        -------
        None
        """
        columns = ["epoch"]
        if self.scheduler is not None:
            columns.append("learning_rate")
        columns.append("train_loss")
        if self.val_loss is not None:
            columns.append("val_loss")
        if self.train_metrics is not None:
            columns.extend(m.name for m in self.train_metrics)
        if self.val_metrics is not None:
            columns.extend(m.name for m in self.val_metrics)

        # the context manager guarantees the handle is flushed and closed
        with open(self.path, "a") as f:
            f.write(",".join(columns) + "\n")

    def on_epoch_end(self, epoch: int) -> None:
        """
        Append one CSV row with the values of the epoch that just ended.

        The row follows the header order: epoch, learning rate (if any),
        training loss average, validation loss average (if any), last value
        (`val`) of each training metric, average of each validation metric.

        Parameters
        ----------
        epoch : int
            The current epoch number.

        Returns
        -------
        None
        """
        row = [str(epoch)]

        # add the scheduler value
        if self.scheduler is not None:
            row.append(str(self.scheduler.get_last_lr()[0]))

        # add the training loss
        row.append(str(self._as_scalar(self.train_loss.avg)))  # TODO: save the avg value only

        # add the validation loss if needed
        if self.val_loss is not None:
            row.append(str(self._as_scalar(self.val_loss.avg)))

        # add the training metrics
        if self.train_metrics is not None:
            row.extend(str(self._as_scalar(m.val)) for m in self.train_metrics)

        # add the validation metrics
        if self.val_metrics is not None:
            row.extend(str(self._as_scalar(m.avg)) for m in self.val_metrics)

        # write in the output file
        with open(self.path, "a") as f:
            f.write(",".join(row) + "\n")
[docs]
class ImageSaver(Callback):
    """
    Callback that saves visual snapshots of input images, predictions, and ground truth masks at the end of selected epochs.

    For each plotted batch, one 2D slice is taken from the input, the
    prediction and the ground truth, and the three slices are drawn side by
    side. All rows are drawn in a single figure saved as `image_<epoch>.png`.

    :ivar str image_dir: Path where images will be saved.
    :ivar torch.nn.Module model: Model used for inference.
    :ivar torch.utils.data.Dataloader val_dataloader: Validation dataloader providing batches for inference.
    :ivar bool use_sigmoid: Whether to apply sigmoid (binary) or softmax (multiclass) on predictions.
    :ivar int every_epoch: Frequency (in epochs) to save snapshots.
    :ivar int plot_size: Number of batches to visualize (one row each).
    :ivar bool use_fp16: Whether to use AMP/mixed precision during inference.
    """

    def __init__(self,
                 image_dir: str,
                 model: torch.nn.Module,
                 val_dataloader: torch.utils.data.DataLoader,
                 use_sigmoid: bool = True,
                 every_epoch: int = 1,
                 plot_size: int = 1,
                 use_fp16: bool = True,
                 ):
        """
        Initialize the ImageSaver callback.

        Parameters
        ----------
        image_dir : str
            Path to the directory where snapshots will be saved.
        model : torch.nn.Module
            Model used for making predictions.
        val_dataloader : torch.utils.data.DataLoader or iter
            Dataloader or iterable providing (input, target) batches.
        use_sigmoid : bool, default=True
            Whether to apply sigmoid (binary classification) or softmax (multiclass) to the predictions.
        every_epoch : int, default=1
            Snapshot saving frequency (in epochs).
        plot_size : int, default=1
            Number of rows in the snapshot; one batch is drawn per row.
        use_fp16 : bool, default=True
            Whether to enable AMP (automatic mixed precision) during inference.
            Only used when CUDA is available.
        """
        self.image_dir = image_dir
        self.every_epoch = every_epoch
        self.model = model
        self.use_sigmoid = use_sigmoid
        self.val_dataloader = val_dataloader
        self.plot_size = plot_size
        self.use_fp16 = use_fp16

    @staticmethod
    def _middle_slice(volume: torch.Tensor) -> np.ndarray:
        """Return, as a float array, the middle slice (third axis) of the last channel of the last batch element of a 5-D tensor."""
        _, _, depth, _, _ = volume.shape
        return volume[-1, -1, depth // 2, ...].cpu().numpy().astype(float)

    def on_epoch_end(self, epoch: int) -> None:
        """
        Call at the end of each epoch. Saves snapshots of model predictions for visual inspection.

        Nothing happens unless `epoch` is a multiple of `every_epoch`. The
        model is switched to eval mode and is not switched back here.

        Parameters
        ----------
        epoch : int
            The current epoch number.

        Returns
        -------
        None
        """
        if epoch % self.every_epoch != 0:
            return

        self.model.eval()  # TODO: we can do a prediction function because this code is duplicated
        use_cuda = torch.cuda.is_available()

        # one figure for all rows: creating it inside the loop would leave
        # earlier figures open and save only the last row
        plt.figure(dpi=100)

        with torch.no_grad():
            # zip with range() first: at most `plot_size` batches are pulled
            # from a single pass over the dataloader
            for row, (X, y) in zip(range(self.plot_size), self.val_dataloader):
                # make prediction
                if use_cuda:
                    X, y = X.cuda(), y.cuda()
                elif torch.backends.mps.is_available():
                    X, y = X.to('mps'), y.to('mps')

                amp_context = torch.amp.autocast("cuda", enabled=self.use_fp16) if use_cuda else nullcontext()
                with amp_context:
                    pred = self.model(X)
                    if isinstance(pred, list):
                        pred = pred[-1]  # keep the last output when the model returns several
                    if self.use_sigmoid:
                        pred = (torch.sigmoid(pred) > 0.5).int() * 255
                    else:
                        pred = (pred.softmax(dim=1).argmax(dim=1).unsqueeze(dim=1)).int() * 255

                raw = self._middle_slice(X)
                truth = self._middle_slice(y)
                mask = self._middle_slice(pred.detach())

                # plot: raw input, prediction, ground truth
                panels = ((raw, 'raw'), (mask, 'pred'), (truth, 'ground_truth'))
                for col, (image, title) in enumerate(panels, start=1):
                    plt.subplot(self.plot_size, 3, 3 * row + col)
                    plt.imshow(image)
                    plt.title(title)
                    plt.axis('off')

        im_path = os.path.join(self.image_dir, 'image_' + str(epoch) + '.png')
        print("Save image to {}".format(im_path))
        plt.savefig(im_path)
        plt.close()
[docs]
class TensorboardSaver(Callback):
    """
    Callback writing losses and metrics to TensorBoard.

    Training and validation losses, together with the metrics, are logged at
    the end of every epoch. Start TensorBoard from the project directory with:

    `tensorboard --logdir=logs/`

    Tags in use:

    - Loss/train
    - Loss/test
    - Metrics/{name_of_metric}

    :ivar SummaryWriter writer: TensorBoard summary writer.
    :ivar Metric train_loss: Training loss.
    :ivar Metric val_loss: Validation loss.
    :ivar list[Metric] train_metrics: List of training metric modules.
    :ivar list[Metric] val_metrics: List of validation metric modules.
    :ivar int batch_size: Size of training mini-batch.
    :ivar int n_batch_per_epoch: Number of batches per epoch.
    :ivar int crt_epoch: Current epoch (used for iteration tracking).
    """

    def __init__(self,
                 log_dir: str,
                 train_loss: Metric,
                 val_loss: Metric,
                 train_metrics: list[Metric],
                 val_metrics: list[Metric],
                 batch_size: int,
                 n_batch_per_epoch: int):
        """
        Set up the TensorboardSaver callback.

        Parameters
        ----------
        log_dir : str
            Folder receiving the TensorBoard logs.
        train_loss : Metric
            Training loss function.
        val_loss : Metric
            Validation loss function.
        train_metrics : list of Metric
            Training metric objects (must expose .avg and .name).
        val_metrics : list of Metric
            Validation metric objects (must expose .avg and .name).
        batch_size : int
            Mini-batch size, used to compute the number of processed images (x-axis).
        n_batch_per_epoch : int
            Number of batches in each epoch (used to track iteration count).
        """
        self.crt_epoch = 0
        self.batch_size = batch_size
        self.n_batch_per_epoch = n_batch_per_epoch
        self.train_loss = train_loss
        self.val_loss = val_loss
        self.train_metrics = train_metrics
        self.val_metrics = val_metrics
        self.writer = SummaryWriter(log_dir=log_dir)

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Remember the epoch that is about to start.

        Parameters
        ----------
        epoch : int
            Current epoch number.

        Returns
        -------
        None
        """
        self.crt_epoch = epoch

    def on_epoch_end(self, epoch: int) -> None:
        """
        Log losses and metrics once an epoch is over.

        The x-axis value is the number of images processed so far:
        `(epoch + 1) * batch_size * n_batch_per_epoch`.

        Parameters
        ----------
        epoch : int
            Current epoch number.

        Returns
        -------
        None
        """
        step = (epoch + 1) * self.batch_size * self.n_batch_per_epoch

        self.writer.add_scalar('Loss/train', self.train_loss.avg, step)
        if self.val_loss:
            self.writer.add_scalar('Loss/test', self.val_loss.avg, step)

        # training metrics first, validation metrics second; falsy groups are skipped
        for group in (self.train_metrics, self.val_metrics):
            for metric in group or ():
                self.writer.add_scalar('Metrics/' + metric.name, metric.avg, step)
#----------------------------------------------------------------------------
# Printer
[docs]
class LogPrinter(Callback):
    """
    Callback printing training progress to the terminal.

    The epoch counter is printed when an epoch starts, and a line with the
    batch index followed by every metric is printed every `every_batch` batches.

    :ivar list[Metric] metrics: List of metric or loss modules to print (must implement __str__).
    :ivar int nbof_epochs: Total number of epochs.
    :ivar int nbof_batches: Number of batches per epoch.
    :ivar int every_batch: Frequency (in batches) at which logs are printed.
    """

    def __init__(
            self,
            metrics: list[Metric],
            nbof_epochs: int,
            nbof_batches: int,
            every_batch: int = 10):
        """
        Set up the LogPrinter callback.

        Parameters
        ----------
        metrics : list of Metric
            Metrics to display. Losses can be included as well.
        nbof_epochs : int
            Total number of training epochs.
        nbof_batches : int
            Number of batches in each epoch.
        every_batch : int, default=10
            Print logs every `every_batch` batches.
        """
        self.metrics = metrics
        self.every_batch = every_batch
        self.nbof_epochs = nbof_epochs
        self.nbof_batches = nbof_batches

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Print the epoch progress when an epoch starts.

        Parameters
        ----------
        epoch : int
            Current epoch index (0-based); it is displayed 1-based.

        Returns
        -------
        None
        """
        print("Epoch [{:>3d}/{:>3d}]".format(epoch+1, self.nbof_epochs))

    def on_batch_end(self, batch: int) -> None:
        """
        Print the batch index and the metrics when a batch is over.

        Only batches whose index is a multiple of `every_batch` are printed.

        Parameters
        ----------
        batch : int
            Index of the current batch.

        Returns
        -------
        None
        """
        if batch % self.every_batch != 0:
            return
        pieces = ["Batch [{:>3d}/{:>3d}]".format(batch, self.nbof_batches)]
        pieces.extend(", {}".format(metric) for metric in self.metrics)
        print("".join(pieces))
#----------------------------------------------------------------------------
# schedulers
[docs]
class LRSchedulerMultiStep(Callback):
    """
    Multi-step learning rate scheduler callback.

    Wraps `torch.optim.lr_scheduler.MultiStepLR` to schedule learning rate decay
    at specified epoch milestones. This scheduler multiplies the learning rate
    by `gamma` whenever an epoch hits a milestone.

    For more details, see:
    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR

    :ivar torch.optim.lr_scheduler.MultiStepLR scheduler: Internal PyTorch scheduler.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, milestones: list[int], gamma: float = 0.1):
        """
        Initialize the MultiStepLR scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            The optimizer whose learning rate will be scheduled.
        milestones : list of int
            List of epoch indices at which to reduce the learning rate.
        gamma : float, default=0.1
            Multiplicative factor of learning rate decay.
        """
        # `verbose` is intentionally not passed: the argument is deprecated in
        # recent PyTorch (passing it emits a warning and newer releases may
        # reject it), and its default already disables the verbose output.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma)

    def get_last_lr(self) -> list[float]:
        """
        Return the last computed learning rate by the scheduler.

        Returns
        -------
        list of float
            The most recent learning rates for each parameter group.
        """
        return self.scheduler.get_last_lr()

    def on_epoch_end(self, epoch: int) -> None:
        """
        Call at the end of each epoch. Steps the learning rate scheduler.

        Parameters
        ----------
        epoch : int
            The index of the current epoch (not used, the scheduler keeps its own counter).

        Returns
        -------
        None
        """
        self.scheduler.step()
        print("Current learning rate: {}".format(self.scheduler.get_last_lr()))
[docs]
class LRSchedulerCosine(Callback):
    """
    Cosine Annealing learning rate scheduler callback.

    This callback wraps `torch.optim.lr_scheduler.CosineAnnealingLR` to apply a cosine
    decay to the learning rate over a predefined number of epochs. The
    minimum learning rate is fixed to 1e-6.

    For more details, see:
    https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR

    :ivar torch.optim.lr_scheduler.CosineAnnealingLR scheduler: Internal PyTorch scheduler.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 T_max: int):
        """
        Initialize the CosineAnnealingLR scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            The optimizer whose learning rate will be scheduled.
        T_max : int
            Maximum number of iterations (usually the total number of epochs).
        """
        # `verbose` is intentionally not passed: the argument is deprecated in
        # recent PyTorch (passing it emits a warning and newer releases may
        # reject it), and its default already disables the verbose output.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=1e-6)

    def get_last_lr(self) -> list[float]:
        """
        Return the last computed learning rate by the scheduler.

        Returns
        -------
        list of float
            The most recent learning rates for each parameter group.
        """
        return self.scheduler.get_last_lr()

    def on_epoch_end(self, epoch: int) -> None:
        """
        Call at the end of each epoch. Steps the learning rate scheduler.

        Parameters
        ----------
        epoch : int
            The index of the current epoch (not used, the scheduler keeps its own counter).

        Returns
        -------
        None
        """
        self.scheduler.step()
        print("Current learning rate: {}".format(self.scheduler.get_last_lr()))
[docs]
class LRSchedulerPoly(Callback):
    r"""
    Polynomial learning rate scheduler.

    This scheduler decreases the learning rate following a polynomial decay formula,
    similar to what is used in nnU-Net:

    .. math::
        lr_{new} = lr_{initial} * (1 - \frac{epoch_{current}}{epoch_{max}})^{exponent}

    Only the first parameter group of the optimizer is updated. Once
    `max_epochs` is reached the learning rate stays at 0.

    :ivar float initial_lr: Initial learning rate.
    :ivar int max_epochs: Total number of training epochs.
    :ivar float exponent: Exponent controlling the decay rate.
    :ivar torch.optim.Optimizer optimizer: Optimizer being scheduled.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 initial_lr: float,
                 max_epochs: int,
                 exponent: float = 0.9):
        """
        Initialize the polynomial scheduler.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Training optimizer whose learning rate will be scheduled.
        initial_lr : float
            Initial learning rate.
        max_epochs : int
            Total number of epochs for training.
        exponent : float, optional, default=0.9
            Decay exponent controlling how fast the learning rate decreases.
        """
        self.initial_lr = initial_lr
        self.max_epochs = max_epochs
        self.exponent = exponent
        self.optimizer = optimizer

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Update the learning rate at the beginning of each epoch.

        Parameters
        ----------
        epoch : int
            Current epoch index.

        Returns
        -------
        None
        """
        # clamp at 0: past `max_epochs` the base would be negative and a
        # fractional power of a negative float is a complex number in Python
        remaining = max(0.0, 1 - epoch / self.max_epochs)
        self.optimizer.param_groups[0]['lr'] = self.initial_lr * remaining**self.exponent
        print("Current learning rate: {}".format(self.optimizer.param_groups[0]['lr']))
[docs]
class ForceFGScheduler(Callback):
    r"""
    Foreground sampling rate scheduler using polynomial decay.

    This scheduler gradually reduces the rate at which foreground patches are sampled
    during training. It calls the `set_fg_rate` method of the dataset (accessed via the
    dataloader) to adjust the foreground rate at each epoch.

    Note: The dataset associated with the dataloader **must** implement the method:
    `set_fg_rate(rate: float)`.

    The decay follows this formula:

    .. math::
        fg_{rate} = (initial - min) * (1 - \frac{epoch}{max\_epochs})^{exponent} + min

    Once `max_epochs` is reached the rate stays at `min_rate`.

    :ivar torch.utils.data.DataLoader dataloader: Dataloader whose dataset supports `set_fg_rate`.
    :ivar float initial_rate: Starting foreground sampling rate (e.g. 1.0 means only foreground).
    :ivar float min_rate: Final minimal foreground rate (e.g. 0.33 as in nnU-Net).
    :ivar int max_epochs: Total number of epochs.
    :ivar float exponent: Exponent for polynomial decay.
    """

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 initial_rate: float,
                 min_rate: float,
                 max_epochs: int,
                 exponent: float = 0.9):
        """
        Initialize the foreground scheduler.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader
            A dataloader whose dataset must implement `set_fg_rate(rate: float)`.
        initial_rate : float
            Initial sampling rate of foreground (typically 1.0).
        min_rate : float
            Final foreground sampling rate (e.g., 0.33).
        max_epochs : int
            Total number of training epochs.
        exponent : float, optional, default=0.9
            Polynomial decay exponent (between 0 and 1).
        """
        self.dataloader = dataloader
        self.initial_rate = initial_rate
        self.min_rate = min_rate
        self.max_epochs = max_epochs
        self.exponent = exponent

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Adjust foreground sampling rate at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Current epoch index.

        Returns
        -------
        None
        """
        # clamp at 0: past `max_epochs` the base would be negative and a
        # fractional power of a negative float is a complex number in Python
        remaining = max(0.0, 1 - epoch / self.max_epochs)
        crt_rate = (self.initial_rate - self.min_rate) * remaining**self.exponent + self.min_rate
        self.dataloader.dataset.set_fg_rate(crt_rate)
        print("Current foreground rate: {}".format(crt_rate))
[docs]
class OverlapScheduler(Callback):
    """
    Callback to schedule the minimum overlap rate between global and local patches using a polynomial decay.

    The overlap rate is progressively reduced during training from an initial rate to a minimum final rate.
    An overlap of 1 means the local patch is fully inside the global patch.
    An overlap of 0 or less means the local patch can be outside the global patch.
    The exact behavior depends on the dataset implementation.

    Once `max_epochs` is reached the rate stays at `min_rate`.

    :ivar torch.utils.data.DataLoader dataloader: DataLoader whose dataset implements `set_min_overlap`.
    :ivar float initial_rate: Initial overlap rate at the start of training.
    :ivar float min_rate: Minimum overlap rate at the end of training.
    :ivar int max_epochs: Total number of training epochs.
    :ivar float exponent: Exponent for the polynomial decay (between 0 and 1).
    """

    def __init__(
            self,
            dataloader: torch.utils.data.DataLoader,
            initial_rate: float,
            min_rate: float,
            max_epochs: int,
            exponent: float = 0.9,
    ):
        """
        Initialize the OverlapScheduler callback.

        The initial rate is applied to the dataset (and printed) immediately.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader
            DataLoader whose dataset implements a `set_min_overlap` method.
        initial_rate : float
            Starting overlap rate.
        min_rate : float
            Final minimal overlap rate.
        max_epochs : int
            Total number of epochs for training.
        exponent : float, default=0.9
            Exponent controlling the polynomial decay curve.
        """
        self.dataloader = dataloader
        self.initial_rate = initial_rate
        self.min_rate = min_rate
        self.max_epochs = max_epochs
        self.exponent = exponent

        self.dataloader.dataset.set_min_overlap(self.initial_rate)
        print("Current overlap: {}".format(self.initial_rate))

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Adjust and set the minimum overlap rate at the start of an epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number (0-based).

        Returns
        -------
        None
        """
        # clamp at 0: past `max_epochs` the base would be negative and a
        # fractional power of a negative float is a complex number in Python
        remaining = max(0.0, 1 - epoch / self.max_epochs)
        crt_rate = (self.initial_rate - self.min_rate) * remaining**self.exponent + self.min_rate
        self.dataloader.dataset.set_min_overlap(crt_rate)
        print("Current overlap: {}".format(crt_rate))
[docs]
class GlobalScaleScheduler(Callback):
    """
    Callback to schedule the global crop scale using a polynomial decay.

    The scale of the global crop is progressively reduced from the image size to the patch/local crop size.
    Once `max_epochs` is reached the scale stays at `min_rate`.

    :ivar torch.utils.data.DataLoader dataloader: DataLoader whose dataset implements `set_global_crop`.
    :ivar float initial_rate: Initial global scale rate.
    :ivar float min_rate: Minimal/final global scale rate.
    :ivar int max_epochs: Total number of training epochs.
    :ivar float exponent: Exponent for polynomial decay (between 0 and 1).
    """

    def __init__(
            self,
            dataloader: torch.utils.data.DataLoader,
            initial_rate: float,
            min_rate: float,
            max_epochs: int,
            exponent: float = 0.9,
    ):
        """
        Initialize the GlobalScaleScheduler callback.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader
            DataLoader whose dataset implements a `set_global_crop` method.
        initial_rate : float
            Starting global scale rate.
        min_rate : float
            Final minimal global scale rate.
        max_epochs : int
            Total number of epochs.
        exponent : float, default=0.9
            Exponent controlling the polynomial decay.
        """
        self.dataloader = dataloader
        self.initial_rate = initial_rate
        self.min_rate = min_rate
        self.max_epochs = max_epochs
        self.exponent = exponent

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Update the global crop scale at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number (0-based).

        Returns
        -------
        None
        """
        # clamp at 0: past `max_epochs` the base would be negative and a
        # fractional power of a negative float is a complex number in Python
        remaining = max(0.0, 1 - epoch / self.max_epochs)
        crt_rate = (self.initial_rate - self.min_rate) * remaining**self.exponent + self.min_rate
        self.dataloader.dataset.set_global_crop(crt_rate)
        print("Current global crop scale: {}".format(crt_rate))
[docs]
class WeightDecayScheduler(Callback):
    """
    Callback to schedule weight decay during training, useful for DINO re-implementation.

    Only the first parameter group of the optimizer is updated. With the
    polynomial schedule, the weight decay stays at `final_wd` once
    `nb_epochs` is reached.

    :ivar torch.optim.Optimizer optimizer: Optimizer used for training.
    :ivar float initial_wd: Initial weight decay value.
    :ivar float final_wd: Final weight decay value.
    :ivar int nb_epochs: Total number of training epochs.
    :ivar bool use_poly: Whether to use polynomial scheduler (True) or cosine scheduler (False).
    :ivar float exponent: Exponent for polynomial decay (between 0 and 1).
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            initial_wd: float,
            final_wd: float,
            nb_epochs: int,
            use_poly: bool = False,
            exponent: float = 0.9,
    ):
        """
        Initialize the WeightDecayScheduler callback.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer whose weight decay parameter will be scheduled.
        initial_wd : float
            Starting weight decay value.
        final_wd : float
            Final weight decay value.
        nb_epochs : int
            Total number of training epochs.
        use_poly : bool, default=False
            Use polynomial decay if True; cosine decay otherwise.
        exponent : float, default=0.9
            Exponent controlling polynomial decay (only used if use_poly=True).
        """
        self.optimizer = optimizer
        self.initial_wd = initial_wd
        self.final_wd = final_wd
        self.nb_epochs = nb_epochs
        self.use_poly = use_poly
        self.exponent = exponent

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Update weight decay at the start of each epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number (0-based).

        Returns
        -------
        None
        """
        span = self.initial_wd - self.final_wd
        if self.use_poly:
            # clamp at 0: past `nb_epochs` the base would be negative and a
            # fractional power of a negative float is a complex number in Python
            remaining = max(0.0, 1 - epoch / self.nb_epochs)
            weight_decay = self.final_wd + span * remaining**self.exponent
        else:
            weight_decay = self.final_wd + 0.5 * span * (1 + np.cos(np.pi * epoch / self.nb_epochs))
        self.optimizer.param_groups[0]["weight_decay"] = weight_decay
        print("Current weight decay:", self.optimizer.param_groups[0]["weight_decay"] )
[docs]
class MomentumScheduler(Callback):
    """
    Momentum scheduler for DINO re-implementation.

    Schedules momentum with different modes: polynomial, exponential, linear, or cosine decay.
    The value for a given epoch is obtained by indexing the scheduler
    (``scheduler[epoch]``), which also stores it in `crt_momentum`.
    In 'poly', 'linear' and cosine modes the value is held at
    `final_momentum` once the epoch goes past `nb_epochs`.

    :ivar float initial_momentum: Initial momentum value.
    :ivar float final_momentum: Final momentum value.
    :ivar int nb_epochs: Total number of epochs.
    :ivar str mode: Scheduling mode ('poly', 'exp', 'linear', or None for cosine).
    :ivar float exponent: Exponent for polynomial or exponential decay (between 0 and 1).
    :ivar float crt_momentum: Current momentum value (None until the scheduler is indexed).
    """

    def __init__(
        self,
        initial_momentum: float,
        final_momentum: float,
        nb_epochs: int,
        mode: Literal['poly','exp','linear'] | None = 'poly',
        exponent: float = 0.9
    ):
        """
        Initialize the MomentumScheduler.

        Parameters
        ----------
        initial_momentum : float
            Starting momentum value.
        final_momentum : float
            Ending momentum value.
        nb_epochs : int
            Total number of training epochs.
        mode : 'poly','exp','linear' or None, default='poly'
            Scheduling mode: 'poly', 'exp', 'linear', or None (for cosine).
            Any other value also selects the cosine schedule.
        exponent : float, default=0.9
            Exponent controlling polynomial or exponential decay.
        """
        # Let the base class set up its own attributes (`metrics`).
        super().__init__()
        self.initial_momentum = initial_momentum
        self.final_momentum = final_momentum
        self.nb_epochs = nb_epochs
        self.crt_momentum = None
        self.mode = mode
        self.exponent = exponent

    def __getitem__(self, epoch: int) -> float:
        """
        Compute momentum value at given epoch and store it in `crt_momentum`.

        Parameters
        ----------
        epoch : int
            Current epoch number (0-based).

        Returns
        -------
        float
            Computed momentum value.
        """
        momentum_span = self.initial_momentum - self.final_momentum
        if self.mode=='exp':
            # Exponential decay is defined for any epoch; it approaches but
            # does not reach `final_momentum`.
            self.crt_momentum = self.final_momentum + momentum_span * (np.exp(-self.exponent*epoch / self.nb_epochs))
        elif self.mode=='linear':
            if epoch > self.nb_epochs:
                self.crt_momentum = self.final_momentum
            else:
                self.crt_momentum = self.initial_momentum + (self.final_momentum - self.initial_momentum) * epoch / self.nb_epochs
        elif epoch >= self.nb_epochs:
            # 'poly' and cosine both reach `final_momentum` at `nb_epochs`.
            # Past that point the polynomial base would be negative (complex
            # result with a fractional exponent) and the cosine would rise
            # again, so hold the final value as the linear mode does.
            self.crt_momentum = self.final_momentum
        elif self.mode=='poly':
            self.crt_momentum = self.final_momentum + momentum_span * (1 - epoch / self.nb_epochs)**self.exponent
        else: # cosine
            self.crt_momentum = self.final_momentum + 0.5 * momentum_span * (1 + np.cos(np.pi * epoch / self.nb_epochs))
        return self.crt_momentum

    def on_epoch_end(self, epoch: int) -> None:
        """
        Call at the end of each epoch.

        Prints the momentum value most recently computed through indexing.

        Parameters
        ----------
        epoch : int
            Current epoch number (not used).

        Returns
        -------
        None
        """
        print("Current teacher momentum:", self.crt_momentum)
[docs]
class DatasetSizeScheduler(Callback):
    """
    Dataset size scheduler.

    Grows the dataset progressively to aid Arcface training: the size starts
    at `min_dataset_size` and increases by 1 at each epoch until it reaches
    `max_dataset_size`.
    The dataloader's dataset must implement a `set_dataset_size` method.
    The model must implement a `set_num_classes` method.

    :ivar torch.utils.data.DataLoader dataloader: Dataloader with dataset implementing `set_dataset_size`.
    :ivar torch.nn.Module model: Model implementing `set_num_classes`.
    :ivar int max_dataset_size: Maximum dataset size.
    :ivar int min_dataset_size: Minimum dataset size at start of training.

    Notes
    -----
    No proof that this improves final performance.
    """

    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        max_dataset_size: int,
        min_dataset_size: int = 5,
    ):
        """
        Initialize the DatasetSizeScheduler.

        Parameters
        ----------
        dataloader : torch.utils.data.DataLoader
            Dataloader whose dataset must have a `set_dataset_size` method.
        model : torch.nn.Module
            Model with a `set_num_classes` method.
        max_dataset_size : int
            Upper bound of the dataset size.
        min_dataset_size : int, default=5
            Dataset size used at epoch 0.
        """
        # Size bounds of the schedule.
        self.min_dataset_size = min_dataset_size
        self.max_dataset_size = max_dataset_size
        # Objects kept in sync with the current size.
        self.model = model
        self.dataloader = dataloader

    def on_epoch_begin(self, epoch: int) -> None:
        """
        Call at the beginning of an epoch.

        Sets the dataset size for this epoch and gives the model the same
        value as its number of classes.

        Parameters
        ----------
        epoch : int
            Current epoch index (0-based).

        Returns
        -------
        None
        """
        size_cap = self.max_dataset_size
        crt_size = epoch + self.min_dataset_size
        # Never grow past the configured maximum.
        if size_cap < crt_size:
            crt_size = size_cap

        self.dataloader.dataset.set_dataset_size(crt_size)
        self.model.set_num_classes(crt_size)
        print("Current dataset size: {}".format(crt_size))
#----------------------------------------------------------------------------
# metrics updater
[docs]
class MetricsUpdater(Callback):
    """
    Keep the metric averages up to date during training.

    Every metric is reset when an epoch begins and has its `update` method
    called with the batch size when a batch ends.

    :ivar list[Metric] metrics: List of metrics and losses to update.
    :ivar int batch_size: Batch size used to update metrics.
    """

    def __init__(self, metrics: list[Metric], batch_size: int):
        """
        Initialize the MetricsUpdater callback.

        Parameters
        ----------
        metrics : list of biom3d.Metric
            Metrics and losses handled by this callback.
        batch_size : int
            Value passed to each metric's `update` method.
        """
        self.batch_size = batch_size
        self.metrics = metrics

    def on_epoch_begin(self, epoch: Optional[int] = None) -> None:
        """
        Call at the beginning of an epoch.

        Resets all metrics.

        Parameters
        ----------
        epoch : int, optional
            Current epoch index (not used).

        Returns
        -------
        None
        """
        for metric in self.metrics:
            metric.reset()

    def on_batch_end(self, batch: Optional[int] = None) -> None:
        """
        Call at the end of a batch.

        Updates all metrics with the batch size.

        Parameters
        ----------
        batch : int, optional
            Current batch index (not used).

        Returns
        -------
        None
        """
        for metric in self.metrics:
            metric.update(self.batch_size)
#----------------------------------------------------------------------------