Callbacks

Callbacks are routines called periodically during training. To create a new callback, inherit from the biom3d.callbacks.Callback class and override one or more of its methods. Then add it to the biom3d.builder.Builder.build_callback method.

There are currently three different periods: training, epoch, and batch.
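The sketch below illustrates the three periods and the order in which the hooks fire. To keep it runnable without biom3d, it defines a stand-in base class with no-op hooks; in real code you would inherit from biom3d.callbacks.Callback instead. The toy loop fit and the EventRecorder callback are hypothetical, and biom3d's actual trainer may pass different arguments to the hooks.

```python
# Stand-in for biom3d.callbacks.Callback so this sketch runs without biom3d.
# In real code, inherit from biom3d.callbacks.Callback instead.
class Callback:
    def on_train_begin(self, epoch=None): pass
    def on_train_end(self, epoch=None): pass
    def on_epoch_begin(self, epoch): pass
    def on_epoch_end(self, epoch): pass
    def on_batch_begin(self, batch): pass
    def on_batch_end(self, batch): pass


class EventRecorder(Callback):
    """Hypothetical callback that records which hook was called, and with which index."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, epoch):
        self.events.append(("epoch_begin", epoch))

    def on_batch_end(self, batch):
        self.events.append(("batch_end", batch))

    def on_epoch_end(self, epoch):
        self.events.append(("epoch_end", epoch))


def fit(callback, nb_epochs, nb_batches):
    """Hypothetical toy training loop showing where each period's hooks fire."""
    callback.on_train_begin()
    for epoch in range(nb_epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(nb_batches):
            callback.on_batch_begin(batch)
            # ... forward pass, loss, backward pass, optimizer step ...
            callback.on_batch_end(batch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


recorder = EventRecorder()
fit(recorder, nb_epochs=1, nb_batches=2)
print(recorder.events)
# [('epoch_begin', 0), ('batch_end', 0), ('batch_end', 1), ('epoch_end', 0)]
```

Only the hooks that EventRecorder overrides produce events; the others fall back to the base class no-ops.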

class biom3d.callbacks.Callback

Abstract base class used to build new callbacks.

Callbacks are called periodically during training. There are currently three different periods: training, epoch, and batch. Callbacks are called either before or after a period.

Each method whose name starts with on_ can be overridden. It is called at a specific point of the training process. For instance, the on_epoch_end method is called at the end of each epoch.

Variables:

metrics (list[biom3d.Metric]) – Metrics or training information associated with the callback.

__init__()

Initialization of attributes.

abstractmethod on_batch_begin(batch: int) → None

Call before processing a batch.

Parameters:

batch (int) – The batch index.

Return type:

None

abstractmethod on_batch_end(batch: int) → None

Call after processing a batch.

Parameters:

batch (int) – The batch index.

Return type:

None

abstractmethod on_epoch_begin(epoch: int) → None

Call before processing an epoch.

Parameters:

epoch (int) – The epoch index.

Return type:

None

abstractmethod on_epoch_end(epoch: int) → None

Call after processing an epoch.

Parameters:

epoch (int) – The epoch index.

Return type:

None

abstractmethod on_train_begin(epoch: int | None = None) → None

Call once at the beginning of training.

Parameters:

epoch (int, optional) – The epoch index.

Return type:

None

abstractmethod on_train_end(epoch: int | None = None) → None

Call once at the end of training.

Parameters:

epoch (int, optional) – The epoch index.

Return type:

None

set_trainer(metrics: list[Metric]) → None

Associate metrics or training state to this callback.

Parameters:

metrics (list of biom3d.Metric) – List of metrics or training information used by the callback.

Return type:

None

class biom3d.callbacks.Callbacks(callbacks: dict[str, Any])

Aggregates and manages multiple callbacks, dispatching events to all contained callbacks.

This class acts as a container for multiple Callback instances. It forwards all callback events (on_batch_begin, on_batch_end, on_epoch_begin, on_epoch_end, on_train_begin, on_train_end) to each callback in the collection.

This allows users to easily manage several callbacks in a modular way, for example combining logging, checkpoint saving, metric computation, etc., all in one object.

Variables:

callbacks (dict[str, Callback]) – A dictionary of callbacks.

Examples

>>> cbs = Callbacks({
...     "logger": LogSaver(...),
...     "saver": ModelSaver(...),
... })
>>> cbs.on_epoch_end(10)  # calls on_epoch_end(10) on both LogSaver and ModelSaver
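The dispatching behaviour can be sketched in a few lines. This is a hypothetical re-implementation that shows only on_epoch_end; the other five events would be forwarded the same way. The Printer class is a made-up placeholder callback, and iterating in dictionary insertion order is an assumption about biom3d, not something the reference states.

```python
class Callbacks:
    """Hypothetical sketch of the container described above (not biom3d's code)."""

    def __init__(self, callbacks):
        self.callbacks = callbacks  # dict: name -> callback object

    def on_epoch_end(self, epoch):
        # Forward the event to every contained callback, in insertion order.
        for cb in self.callbacks.values():
            cb.on_epoch_end(epoch)


class Printer:
    """Hypothetical placeholder callback."""

    def __init__(self, name):
        self.name = name
        self.seen = []

    def on_epoch_end(self, epoch):
        self.seen.append(epoch)
        print(f"{self.name}: end of epoch {epoch}")


cbs = Callbacks({"logger": Printer("logger"), "saver": Printer("saver")})
cbs.on_epoch_end(10)
# logger: end of epoch 10
# saver: end of epoch 10
```

Keying callbacks by name keeps individual callbacks addressable (for example cbs.callbacks["saver"]) while the trainer only needs to talk to one object.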

__init__(callbacks: dict[str, Any])

Initialize from a dictionary.

Parameters:

callbacks (dict of str to biom3d.callbacks.Callback) – A dictionary of callbacks.

on_batch_begin(batch: int) → None

Call before processing a batch. Forwards the call to all callbacks.

Parameters:

batch (int) – Current batch index.

Return type:

None

on_batch_end(batch: int) → None

Call after processing a batch. Forwards the call to all callbacks.

Parameters:

batch (int) – Current batch index.

Return type:

None

on_epoch_begin(epoch: int) → None

Call before processing an epoch. Forwards the call to all callbacks.

Parameters:

epoch (int) – Current epoch index.

Return type:

None

on_epoch_end(epoch: int) → None

Call after processing an epoch. Forwards the call to all callbacks.

Parameters:

epoch (int) – Current epoch index.

Return type:

None

on_train_begin(epoch: int | None = None) → None

Call once at the start of training. Forwards the call to all callbacks.

Parameters:

epoch (int, optional) – Starting epoch number, if any.

Return type:

None

on_train_end(epoch: int | None = None) → None

Call once at the end of training. Forwards the call to all callbacks.

Parameters:

epoch (int, optional) – Final epoch number, if any.

Return type:

None

set_trainer(trainer: list[Metric]) → None

Set the trainer or metrics context for each callback.

Parameters:

trainer (list of biom3d.Metric) – Trainer or metrics list to pass to callbacks.

Return type:

None

class biom3d.callbacks.DatasetSizeScheduler(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module, max_dataset_size: int, min_dataset_size: int = 5)

Dataset size scheduler.

Progressively increases the size of the dataset to aid ArcFace training. The dataloader’s dataset must implement a set_dataset_size method, and the model must implement a set_num_classes method. At the beginning of each epoch, the dataset size is incremented by 1 (up to max_dataset_size).

Variables:
  • dataloader (torch.utils.data.DataLoader) – Dataloader with dataset implementing set_dataset_size.

  • model (torch.nn.Module) – Model implementing set_num_classes.

  • max_dataset_size (int) – Maximum dataset size.

  • min_dataset_size (int) – Minimum dataset size at start of training.

Notes

There is no proof that this improves final performance.
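The scheduling logic can be sketched as follows. This is a hypothetical reconstruction, not biom3d's code: it assumes the size at epoch e is min(min_dataset_size + e, max_dataset_size), and that the number of classes passed to set_num_classes equals the dataset size (as in ArcFace-style instance discrimination, where each sample is its own class). ToyDataset, ToyLoader and ToyModel are made-up stand-ins for the DataLoader and model.

```python
class DatasetSizeSchedulerSketch:
    """Hypothetical sketch; the real biom3d.callbacks.DatasetSizeScheduler may differ."""

    def __init__(self, dataloader, model, max_dataset_size, min_dataset_size=5):
        self.dataloader = dataloader
        self.model = model
        self.max_dataset_size = max_dataset_size
        self.min_dataset_size = min_dataset_size

    def on_epoch_begin(self, epoch):
        # Assumption: grow by 1 per epoch from min_dataset_size, capped at max_dataset_size.
        size = min(self.min_dataset_size + epoch, self.max_dataset_size)
        self.dataloader.dataset.set_dataset_size(size)
        # Assumption: one class per sample, so the classifier head follows the dataset size.
        self.model.set_num_classes(size)


class ToyDataset:
    def set_dataset_size(self, size):
        self.size = size


class ToyLoader:
    def __init__(self, dataset):
        self.dataset = dataset


class ToyModel:
    def set_num_classes(self, n):
        self.num_classes = n


loader, model = ToyLoader(ToyDataset()), ToyModel()
sched = DatasetSizeSchedulerSketch(loader, model, max_dataset_size=7, min_dataset_size=5)
for epoch in range(4):
    sched.on_epoch_begin(epoch)
    print(epoch, loader.dataset.size, model.num_classes)
# 0 5 5
# 1 6 6
# 2 7 7
# 3 7 7
```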

__init__(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module, max_dataset_size: int, min_dataset_size: int = 5)

Initialize the DatasetSizeScheduler.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – Dataloader whose dataset must have a set_dataset_size method.

  • model (torch.nn.Module) – Model with a set_num_classes method.

  • max_dataset_size (int) – Maximum size of the dataset.

  • min_dataset_size (int, default=5) – Initial minimal dataset size.

on_epoch_begin(epoch: int) → None

Call at the beginning of an epoch.

Increases the dataset size and updates the model number of classes.

Parameters:

epoch (int) – Current epoch index (0-based).

Return type:

None

class biom3d.callbacks.ForceFGScheduler(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Foreground sampling rate scheduler using polynomial decay.

This scheduler gradually reduces the rate at which foreground patches are sampled during training. It calls the set_fg_rate method of the dataset (accessed via the dataloader) to adjust the foreground rate at each epoch.

Note: The dataset associated with the dataloader must implement the method: set_fg_rate(rate: float).

The decay follows this formula:

\[\mathrm{fg\_rate} = (\mathrm{initial\_rate} - \mathrm{min\_rate}) \left(1 - \frac{\mathrm{epoch}}{\mathrm{max\_epochs}}\right)^{\mathrm{exponent}} + \mathrm{min\_rate}\]

Variables:
  • dataloader (torch.utils.data.DataLoader) – Dataloader whose dataset supports set_fg_rate.

  • initial_rate (float) – Starting foreground sampling rate (e.g. 1.0 means only foreground).

  • min_rate (float) – Final minimal foreground rate (e.g. 0.33 as in nnU-Net).

  • max_epochs (int) – Total number of epochs.

  • exponent (float) – Exponent for polynomial decay.
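The decay formula can be evaluated directly to see how the foreground rate evolves. This sketch assumes the callback plugs the 0-based epoch index into the formula and passes the result to set_fg_rate; poly_decay is a hypothetical helper name, not part of biom3d.

```python
def poly_decay(initial_rate, min_rate, epoch, max_epochs, exponent=0.9):
    """Polynomial decay from initial_rate (epoch 0) down to min_rate (epoch max_epochs)."""
    return (initial_rate - min_rate) * (1 - epoch / max_epochs) ** exponent + min_rate


# Foreground rate going from 1.0 (only foreground) to 0.33 over 100 epochs.
for epoch in (0, 50, 100):
    print(epoch, round(poly_decay(1.0, 0.33, epoch, max_epochs=100), 3))
# 0 1.0
# 50 0.689
# 100 0.33
```

With an exponent below 1 the curve decreases slowly at first and faster near the end, so most of the training still oversamples foreground patches.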

__init__(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Initialize the foreground scheduler.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – A dataloader whose dataset must implement set_fg_rate(rate: float).

  • initial_rate (float) – Initial sampling rate of foreground (typically 1.0).

  • min_rate (float) – Final foreground sampling rate (e.g., 0.33).

  • max_epochs (int) – Total number of training epochs.

  • exponent (float, optional, default=0.9) – Polynomial decay exponent (between 0 and 1).

on_epoch_begin(epoch: int) → None

Adjust foreground sampling rate at the beginning of an epoch.

Parameters:

epoch (int) – Current epoch index.

Return type:

None

class biom3d.callbacks.GlobalScaleScheduler(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Callback to schedule the global crop scale using a polynomial decay.

The scale of the global crop is progressively reduced from the image size to the patch/local crop size.

Variables:
  • dataloader (torch.utils.data.DataLoader) – DataLoader whose dataset implements set_global_crop.

  • initial_rate (float) – Initial global scale rate.

  • min_rate (float) – Minimal/final global scale rate.

  • max_epochs (int) – Total number of training epochs.

  • exponent (float) – Exponent for polynomial decay (between 0 and 1).

__init__(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Initialize the GlobalScaleScheduler callback.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – DataLoader whose dataset implements a set_global_crop method.

  • initial_rate (float) – Starting global scale rate.

  • min_rate (float) – Final minimal global scale rate.

  • max_epochs (int) – Total number of epochs.

  • exponent (float, default=0.9) – Exponent controlling the polynomial decay.

on_epoch_begin(epoch: int) → None

Update the global crop scale at the beginning of an epoch.

Parameters:

epoch (int) – Current epoch number (0-based).

Return type:

None

class biom3d.callbacks.ImageSaver(image_dir: str, model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader, use_sigmoid: bool = True, every_epoch: int = 1, plot_size: int = 1, use_fp16: bool = True)

Callback that saves visual snapshots of input images, predictions, and ground truth masks at the end of selected epochs.

Typically used with 3D medical images. It slices through the middle channel of the input volume and saves a 2D projection of input, predicted mask, and ground truth mask.

Variables:
  • image_dir (str) – Path where images will be saved.

  • model (torch.nn.Module) – Model used for inference.

  • val_dataloader (torch.utils.data.DataLoader) – Validation dataloader providing batches for inference.

  • use_sigmoid (bool) – Whether to apply sigmoid (binary) or softmax (multiclass) on predictions.

  • every_epoch (int) – Frequency (in epochs) to save snapshots.

  • plot_size (int) – Number of images from the batch to visualize.

  • use_fp16 (bool) – Whether to use AMP/mixed precision during inference.

__init__(image_dir: str, model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader, use_sigmoid: bool = True, every_epoch: int = 1, plot_size: int = 1, use_fp16: bool = True)

Initialize the ImageSaver callback.

Parameters:
  • image_dir (str) – Path to the directory where snapshots will be saved.

  • model (torch.nn.Module) – Model used for making predictions.

  • val_dataloader (torch.utils.data.DataLoader or iter) – Dataloader or iterable providing (input, target) batches.

  • use_sigmoid (bool, default=True) – Whether to apply sigmoid (binary classification) or softmax (multiclass) to the predictions.

  • every_epoch (int, default=1) – Snapshot saving frequency (in epochs).

  • plot_size (int, default=1) – Number of examples from the batch to save in the snapshot.

  • use_fp16 (bool, default=True) – Whether to enable AMP (automatic mixed precision) during inference.

on_epoch_end(epoch: int) → None

Call at the end of each epoch. Saves snapshots of model predictions for visual inspection.

Parameters:

epoch (int) – The current epoch number.

Return type:

None

class biom3d.callbacks.LRSchedulerCosine(optimizer: torch.optim.Optimizer, T_max: int)

Cosine Annealing learning rate scheduler callback.

This callback wraps torch.optim.lr_scheduler.CosineAnnealingLR to apply a cosine decay to the learning rate over a predefined number of epochs.

For more details, see: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR

Variables:

scheduler (torch.optim.lr_scheduler.CosineAnnealingLR) – Internal PyTorch scheduler.
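The PyTorch documentation gives a closed form for cosine annealing, evaluated below without torch. This sketch assumes eta_min=0 (PyTorch's default; whether LRSchedulerCosine sets another value is not stated here) and that the scheduler is stepped once per epoch, so t counts completed epochs. cosine_annealing_lr is a hypothetical helper name.

```python
import math


def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    """Learning rate after t scheduler steps, for 0 <= t <= t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


for t in (0, 50, 100):
    print(t, round(cosine_annealing_lr(1e-2, t, t_max=100), 6))
# 0 0.01
# 50 0.005
# 100 0.0
```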

__init__(optimizer: torch.optim.Optimizer, T_max: int)

Initialize the CosineAnnealingLR scheduler.

Parameters:
  • optimizer (torch.optim.Optimizer) – The optimizer whose learning rate will be scheduled.

  • T_max (int) – Maximum number of iterations (usually the total number of epochs).

get_last_lr() → list[float]

Return the last computed learning rate by the scheduler.

Returns:

The most recent learning rates for each parameter group.

Return type:

list of float

on_epoch_end(epoch: int) → None

Call at the end of each epoch. Steps the learning rate scheduler.

Parameters:

epoch (int) – The index of the current epoch.

Return type:

None

class biom3d.callbacks.LRSchedulerMultiStep(optimizer: torch.optim.Optimizer, milestones: list[int], gamma: float = 0.1)

Multi-step learning rate scheduler callback.

Wraps torch.optim.lr_scheduler.MultiStepLR to schedule learning rate decay at specified epoch milestones. This scheduler multiplies the learning rate by gamma whenever an epoch hits a milestone.

For more details, see: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR

Variables:

scheduler (torch.optim.lr_scheduler.MultiStepLR) – Internal PyTorch scheduler.
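The effect of the milestones can be computed in closed form: the learning rate is multiplied by gamma once for every milestone already reached. This sketch assumes one scheduler step per epoch, so t counts completed epochs; multistep_lr is a hypothetical helper name.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, t):
    """Learning rate after t scheduler steps: base_lr times gamma per milestone reached."""
    return base_lr * gamma ** bisect_right(sorted(milestones), t)


for t in (0, 29, 30, 79, 80):
    print(t, f"{multistep_lr(1.0, [30, 80], 0.1, t):g}")
# 0 1
# 29 1
# 30 0.1
# 79 0.1
# 80 0.01
```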

__init__(optimizer: torch.optim.Optimizer, milestones: list[int], gamma: float = 0.1)

Initialize the MultiStepLR scheduler.

Parameters:
  • optimizer (torch.optim.Optimizer) – The optimizer whose learning rate will be scheduled.

  • milestones (list of int) – List of epoch indices at which to reduce the learning rate.

  • gamma (float, default=0.1) – Multiplicative factor of learning rate decay.

get_last_lr() → list[float]

Return the last computed learning rate by the scheduler.

Returns:

The most recent learning rates for each parameter group.

Return type:

list of float

on_epoch_end(epoch: int) → None

Call at the end of each epoch. Steps the learning rate scheduler.

Parameters:

epoch (int) – The index of the current epoch.

Return type:

None

class biom3d.callbacks.LRSchedulerPoly(optimizer: torch.optim.Optimizer, initial_lr: float, max_epochs: int, exponent: float = 0.9)

Polynomial learning rate scheduler.

This scheduler decreases the learning rate following a polynomial decay formula, similar to what is used in nnU-Net:

\[\mathrm{lr}_{new} = \mathrm{lr}_{initial} \left(1 - \frac{\mathrm{epoch}_{current}}{\mathrm{epoch}_{max}}\right)^{\mathrm{exponent}}\]

Variables:
  • initial_lr (float) – Initial learning rate.

  • max_epochs (int) – Total number of training epochs.

  • exponent (float) – Exponent controlling the decay rate.

  • optimizer (torch.optim.Optimizer) – Optimizer being scheduled.
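A polynomial learning rate callback typically writes the new value into every parameter group of the optimizer. The sketch below is a hypothetical re-implementation, not biom3d's code: PolyLRSketch applies the formula above at on_epoch_begin, and ToyOptimizer is a stand-in exposing only a param_groups list of dicts, the attribute torch.optim.Optimizer uses to hold per-group learning rates.

```python
class ToyOptimizer:
    """Stand-in exposing param_groups, like torch.optim.Optimizer does."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class PolyLRSketch:
    """Hypothetical sketch of a polynomial LR callback (not biom3d's actual code)."""

    def __init__(self, optimizer, initial_lr, max_epochs, exponent=0.9):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_epochs = max_epochs
        self.exponent = exponent

    def on_epoch_begin(self, epoch):
        new_lr = self.initial_lr * (1 - epoch / self.max_epochs) ** self.exponent
        # Write the new value into every parameter group.
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr


opt = ToyOptimizer(lr=1e-2)
sched = PolyLRSketch(opt, initial_lr=1e-2, max_epochs=1000)
sched.on_epoch_begin(500)
print(opt.param_groups[0]["lr"])  # about 0.00536
```

Because the value is recomputed from initial_lr and the epoch index rather than updated incrementally, resuming training at an arbitrary epoch yields the same learning rate as an uninterrupted run.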

__init__(optimizer: torch.optim.Optimizer, initial_lr: float, max_epochs: int, exponent: float = 0.9)

Initialize the polynomial scheduler.

Parameters:
  • optimizer (torch.optim.Optimizer) – Training optimizer whose learning rate will be scheduled.

  • initial_lr (float) – Initial learning rate.

  • max_epochs (int) – Total number of epochs for training.

  • exponent (float, optional, default=0.9) – Decay exponent controlling how fast the learning rate decreases.

on_epoch_begin(epoch: int) → None

Update the learning rate at the beginning of each epoch.

Parameters:

epoch (int) – Current epoch index.

Return type:

None

class biom3d.callbacks.LogPrinter(metrics: list[Metric], nbof_epochs: int, nbof_batches: int, every_batch: int = 10)

Callback used to print training logs to the terminal during training.

Logs the current epoch and periodically prints batch information with associated metrics and losses.

Variables:
  • metrics (list[Metric]) – List of metric or loss modules to print (must implement __str__).

  • nbof_epochs (int) – Total number of epochs.

  • nbof_batches (int) – Number of batches per epoch.

  • every_batch (int) – Frequency (in batches) at which logs are printed.

__init__(metrics: list[Metric], nbof_epochs: int, nbof_batches: int, every_batch: int = 10)

Initialize the LogPrinter callback.

Parameters:
  • metrics (list of Metric) – List of metrics to display. Losses can be included as well.

  • nbof_epochs (int) – Total number of training epochs.

  • nbof_batches (int) – Number of batches in each epoch.

  • every_batch (int, default=10) – Print logs every every_batch batches.

on_batch_end(batch: int) → None

Call at the end of a batch. Prints batch index and associated metrics.

Parameters:

batch (int) – Index of the current batch.

Return type:

None

on_epoch_begin(epoch: int) → None

Call at the beginning of an epoch. Prints epoch progress.

Parameters:

epoch (int) – Current epoch index (0-based).

Return type:

None

class biom3d.callbacks.LogSaver(log_dir: str, train_loss: Metric, val_loss: Metric | None = None, train_metrics: Metric | None = None, val_metrics: Metric | None = None, scheduler: Callback | None = None)

Callback that logs training and validation metrics to a CSV file at the end of each epoch.

This logger writes a log.csv file in the specified log directory and appends, for each epoch:

  • Epoch number

  • Learning rate (if a scheduler is provided)

  • Training and validation loss

  • Training and validation metrics

Variables:
  • path (str) – Full path to the log CSV file.

  • crt_epoch (int) – Current epoch (starts at 1).

  • train_loss (biom3d.Metric) – Training loss object.

  • val_loss (biom3d.Metric) – Validation loss object.

  • train_metrics (list[biom3d.Metric]) – List of training metrics.

  • val_metrics (list[biom3d.Metric]) – List of validation metrics.

  • scheduler (Optional[Callback]) – Scheduler providing current learning rate.
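Appending one row per epoch to a CSV file can be sketched with Python's standard csv module. The helper append_log_row, the column names and the values are all made up for illustration; biom3d's actual header (written by write_file_head) and row layout may differ.

```python
import csv
import os
import tempfile


def append_log_row(path, row, header):
    """Append one epoch row to a CSV log, writing the header first if the file is new."""
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(header)
        writer.writerow(row)


# Hypothetical column names and made-up values, for illustration only.
header = ["epoch", "learning_rate", "train_loss", "val_loss", "val_dice"]
log_path = os.path.join(tempfile.mkdtemp(), "log.csv")
append_log_row(log_path, [1, 0.01, 0.52, 0.61, 0.70], header)
append_log_row(log_path, [2, 0.0098, 0.47, 0.58, 0.73], header)

with open(log_path, newline="") as f:
    print(f.read())
```

Opening the file in append mode at every epoch means the log survives a crash up to the last completed epoch.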

__init__(log_dir: str, train_loss: Metric, val_loss: Metric | None = None, train_metrics: Metric | None = None, val_metrics: Metric | None = None, scheduler: Callback | None = None)

Initialize the log saver.

Parameters:
  • log_dir (str) – Path to the directory where the CSV log file will be created.

  • train_loss (biom3d.Metric) – Metric object tracking training loss.

  • val_loss (biom3d.Metric, optional) – Metric object tracking validation loss.

  • train_metrics (list of biom3d.Metric, optional) – List of training metrics to log.

  • val_metrics (list of biom3d.Metric, optional) – List of validation metrics to log.

  • scheduler (Callback, optional) – Learning rate scheduler for logging the current learning rate.

on_epoch_end(epoch: int) → None

Call at the end of each epoch. Appends a row with the current epoch, learning rate, losses, and metrics to the CSV file.

Return type:

None

write_file_head() → None

Write the header of the CSV file with appropriate column names.

Return type:

None

class biom3d.callbacks.MetricsUpdater(metrics: list[Metric], batch_size: int)

Update the running averages of the metrics by calling the update method of each metric.

Variables:
  • metrics (list[Metric]) – List of metrics and losses to update.

  • batch_size (int) – Batch size used to update metrics.
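A batch-size-weighted running average illustrates what resetting at on_epoch_begin and updating at on_batch_end achieve. AverageMetric is a hypothetical stand-in for biom3d.Metric: it assumes each metric stores its latest batch value in val and that update(n) receives the number of samples. The avg and name attributes follow the description of TensorboardSaver; val, sum and count are made-up names.

```python
class AverageMetric:
    """Hypothetical stand-in for biom3d.Metric: batch-size-weighted running average."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0.0  # value computed on the most recent batch
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, n=1):
        # Weight the latest value by the number of samples it was computed on.
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count


loss = AverageMetric("train_loss")
loss.reset()                 # MetricsUpdater resets all metrics in on_epoch_begin
for batch_loss in (0.8, 0.6, 0.4):
    loss.val = batch_loss    # normally set when the loss is computed
    loss.update(2)           # MetricsUpdater updates with batch_size in on_batch_end
print(round(loss.avg, 4))    # 0.6
```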

__init__(metrics: list[Metric], batch_size: int)

Initialize the MetricsUpdater callback.

Parameters:
  • metrics (list of biom3d.Metric) – List of metrics and losses to update.

  • batch_size (int) – Batch size.

on_batch_end(batch: int | None = None) → None

Call at the end of a batch.

Updates all metrics with the batch size.

Parameters:

batch (int, optional) – Current batch index (not used).

Return type:

None

on_epoch_begin(epoch: int | None = None) → None

Call at the beginning of an epoch.

Resets all metrics.

Parameters:

epoch (int, optional) – Current epoch index (not used).

Return type:

None

class biom3d.callbacks.ModelSaver(model: torch.nn.Module, optimizer: torch.optim.Optimizer, path: str = 'unet', every_epoch: int = 2, save_best: bool = True, loss: Metric | None = None, saved_loss: Metric | None = None)

Saves the model, optimizer state, epoch, and loss every every_epoch epochs.

Can also save the best model based on a monitored loss metric.

Variables:
  • model (torch.nn.Module | list[torch.nn.Module]) – Model to store. If a list is given, it is assumed to be [student, teacher].

  • optimizer (torch.optim.Optimizer) – torch optimizer.

  • path (str) – Name of the file representing the model.

  • path_last (str) – path + ‘.pth’

  • path_best (str) – path + ‘_best.pth’

  • every_epoch (int) – Number of epochs between saves.

  • save_best (bool) – Whether to save best or not.

  • best_loss (float) – Best loss value since the beginning of training.

  • loss (biom3d.Metric) – Loss function.

  • saved_loss (biom3d.Metric) – Loss function to save alongside model.

__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, path: str = 'unet', every_epoch: int = 2, save_best: bool = True, loss: Metric | None = None, saved_loss: Metric | None = None)

Initialize the saver.

Parameters:
  • model (torch.nn.Module or a list of torch.nn.Module) – A torch module to store. If a list is given, it is considered to be [student, teacher].

  • optimizer (torch.optim.Optimizer) – A torch optimizer.

  • path (str) – Name of the model to store. The .pth extension is added automatically, and the best model is stored with the _best.pth suffix.

  • every_epoch (int, default=2) – Number of epochs between saves.

  • save_best (bool, default=True) – Whether to save the best model.

  • loss (biom3d.Metric, optional) – Loss function monitored to select the best model.

  • saved_loss (biom3d.Metric, optional) – Loss saved alongside the model.

on_epoch_end(epoch: int) → None

Call at the end of each epoch.

Saves the model state, optimizer state, epoch number, and loss state. Saves every every_epoch epochs, and optionally saves the best model if the monitored loss improves.

Parameters:

epoch (int) – The current epoch number.

Return type:

None

on_train_begin(epoch: int) → None

Call once at the beginning of training.

If model checkpoint files already exist, creates backups by copying them to new filenames suffixed with the current epoch number to prevent overwriting.

Parameters:

epoch (int) – The starting epoch number.

Return type:

None

class biom3d.callbacks.MomentumScheduler(initial_momentum: float, final_momentum: float, nb_epochs: int, mode: Literal['poly', 'exp', 'linear'] | None = 'poly', exponent: float = 0.9)

Momentum scheduler for DINO re-implementation.

Schedules momentum with different modes: polynomial, exponential, linear, or cosine decay.

Variables:
  • optimizer (torch.optim.Optimizer) – Optimizer (optional, if needed).

  • initial_momentum (float) – Initial momentum value.

  • final_momentum (float) – Final momentum value.

  • nb_epochs (int) – Total number of epochs.

  • mode (str) – Scheduling mode (‘poly’, ‘exp’, ‘linear’, or None for cosine).

  • exponent (float) – Exponent for polynomial or exponential decay (between 0 and 1).

  • crt_momentum (float) – Current momentum value.
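Two of the modes can be sketched as plain functions. The cosine form follows the schedule used in the DINO reference code, where the teacher momentum rises from 0.996 towards 1.0; the linear form is a straight interpolation. Whether MomentumScheduler's mode=None and 'linear' branches use exactly these expressions is an assumption, the 'poly' and 'exp' modes are not shown, and both helper names are hypothetical.

```python
import math


def cosine_momentum(initial, final, epoch, nb_epochs):
    """Cosine schedule from `initial` (epoch 0) to `final` (epoch nb_epochs)."""
    return final + 0.5 * (initial - final) * (1 + math.cos(math.pi * epoch / nb_epochs))


def linear_momentum(initial, final, epoch, nb_epochs):
    """Linear interpolation from `initial` (epoch 0) to `final` (epoch nb_epochs)."""
    return initial + (final - initial) * epoch / nb_epochs


for epoch in (0, 25, 50, 100):
    print(epoch,
          round(cosine_momentum(0.996, 1.0, epoch, 100), 5),
          round(linear_momentum(0.996, 1.0, epoch, 100), 5))
```

Both schedules share their endpoints and midpoint; the cosine form changes slowly near the start and the end, which keeps the teacher stable early in training.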

__init__(initial_momentum: float, final_momentum: float, nb_epochs: int, mode: Literal['poly', 'exp', 'linear'] | None = 'poly', exponent: float = 0.9)

Initialize the MomentumScheduler.

Parameters:
  • initial_momentum (float) – Starting momentum value.

  • final_momentum (float) – Ending momentum value.

  • nb_epochs (int) – Total number of training epochs.

  • mode ('poly','exp','linear' or None, default='poly') – Scheduling mode: ‘poly’, ‘exp’, ‘linear’, or None (for cosine).

  • exponent (float, default=0.9) – Exponent controlling polynomial or exponential decay.

on_epoch_end(epoch: int) → None

Call at the end of each epoch.

Parameters:

epoch (int) – Current epoch number.

Return type:

None

class biom3d.callbacks.OverlapScheduler(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Callback to schedule the minimum overlap rate between global and local patches using a polynomial decay.

The overlap rate is progressively reduced during training from an initial rate to a minimum final rate. An overlap of 1 means the local patch is fully inside the global patch. An overlap of 0 or less means the local patch can be outside the global patch. The exact behavior depends on the dataset implementation.

Variables:
  • dataloader (torch.utils.data.DataLoader) – DataLoader whose dataset implements set_min_overlap.

  • initial_rate (float) – Initial overlap rate at the start of training.

  • min_rate (float) – Minimum overlap rate at the end of training.

  • max_epochs (int) – Total number of training epochs.

  • exponent (float) – Exponent for the polynomial decay (between 0 and 1).

__init__(dataloader: torch.utils.data.DataLoader, initial_rate: float, min_rate: float, max_epochs: int, exponent: float = 0.9)

Initialize the OverlapScheduler callback.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – DataLoader whose dataset implements a set_min_overlap method.

  • initial_rate (float) – Starting overlap rate.

  • min_rate (float) – Final minimal overlap rate.

  • max_epochs (int) – Total number of epochs for training.

  • exponent (float, default=0.9) – Exponent controlling the polynomial decay curve.

on_epoch_begin(epoch: int) → None

Adjust and set the minimum overlap rate at the start of an epoch.

Parameters:

epoch (int) – Current epoch number (0-based).

Return type:

None

class biom3d.callbacks.TensorboardSaver(log_dir: str, train_loss: Metric, val_loss: Metric, train_metrics: list[Metric], val_metrics: list[Metric], batch_size: int, n_batch_per_epoch: int)

Callback to log losses and metrics to TensorBoard.

This callback plots training and validation losses as well as metrics during training. Launch TensorBoard from the project directory with: tensorboard --logdir=logs/

The following tags are used:

  • Loss/train

  • Loss/test

  • Metrics/{name_of_metric}

Variables:
  • writer (SummaryWriter) – TensorBoard summary writer.

  • train_loss (Metric) – Training loss.

  • val_loss (Metric) – Validation loss.

  • train_metrics (list[Metric]) – List of training metric modules.

  • val_metrics (list[Metric]) – List of validation metric modules.

  • batch_size (int) – Size of training mini-batch.

  • n_batch_per_epoch (int) – Number of batches per epoch.

  • crt_epoch (int) – Current epoch (used for iteration tracking).

__init__(log_dir: str, train_loss: Metric, val_loss: Metric, train_metrics: list[Metric], val_metrics: list[Metric], batch_size: int, n_batch_per_epoch: int)

Initialize the TensorboardSaver callback.

Parameters:
  • log_dir (str) – Path to the folder where TensorBoard logs will be saved.

  • train_loss (Metric) – Training loss function.

  • val_loss (Metric) – Validation loss function.

  • train_metrics (list of Metric) – List of training metric objects (must expose .avg and .name).

  • val_metrics (list of Metric) – List of validation metric objects (must expose .avg and .name).

  • batch_size (int) – Mini-batch size, used to compute total number of images processed (x-axis).

  • n_batch_per_epoch (int) – Number of batches in each epoch (used to track iteration count).

on_epoch_begin(epoch: int) → None

Call before the start of an epoch.

Parameters:

epoch (int) – Current epoch number.

Return type:

None

on_epoch_end(epoch: int) → None

Call after an epoch ends.

Logs the training and validation losses, as well as metrics, to TensorBoard.

Parameters:

epoch (int) – Current epoch number.

Return type:

None

class biom3d.callbacks.WeightDecayScheduler(optimizer: torch.optim.Optimizer, initial_wd: float, final_wd: float, nb_epochs: int, use_poly: bool = False, exponent: float = 0.9)

Callback to schedule weight decay during training, useful for DINO re-implementation.

Variables:
  • optimizer (torch.optim.Optimizer) – Optimizer used for training.

  • initial_wd (float) – Initial weight decay value.

  • final_wd (float) – Final weight decay value.

  • nb_epochs (int) – Total number of training epochs.

  • use_poly (bool) – Whether to use polynomial scheduler (True) or cosine scheduler (False).

  • exponent (float) – Exponent for polynomial decay (between 0 and 1).

__init__(optimizer: torch.optim.Optimizer, initial_wd: float, final_wd: float, nb_epochs: int, use_poly: bool = False, exponent: float = 0.9)

Initialize the WeightDecayScheduler callback.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer whose weight decay parameter will be scheduled.

  • initial_wd (float) – Starting weight decay value.

  • final_wd (float) – Final weight decay value.

  • nb_epochs (int) – Total number of training epochs.

  • use_poly (bool, default=False) – Use polynomial decay if True; cosine decay otherwise.

  • exponent (float, default=0.9) – Exponent controlling polynomial decay (only used if use_poly=True).

on_epoch_begin(epoch: int) → None

Update weight decay at the start of each epoch.

Parameters:

epoch (int) – Current epoch number (0-based).

Return type:

None