# Source code for biom3d.models.unet3d_vgg_deep

"""Biom3d adaptation of nnUnet base model."""

from typing import Optional
import torch
from torch import nn

from biom3d.models.encoder_vgg import EncoderBlock, VGGEncoder
from biom3d.models.decoder_vgg_deep import VGGDecoder

#---------------------------------------------------------------------------
# 3D UNet with the previous encoder and decoder

class UNet(nn.Module):
    """
    A 3D UNet architecture utilizing VGG-style encoder and decoder blocks for volumetric (3D) image segmentation.

    The UNet model is a convolutional neural network for fast and precise
    segmentation of images. This implementation incorporates VGG blocks for
    encoding and decoding, allowing for deep feature extraction and
    reconstruction, respectively. The model supports dynamic adjustment of
    pooling layers and class numbers, along with optional deep decoder usage
    and weight initialization from pre-trained checkpoints.

    :ivar VGGEncoder encoder: The encoder part of the UNet, responsible for downscaling and feature extraction.
    :ivar VGGDecoder decoder: The decoder part of the UNet, responsible for upscaling and constructing the segmentation map.
    """

    # (old, new) substring substitutions applied, in order, to the keys of an
    # encoder checkpoint stored under the 'model' entry:
    #   - remove `module.` prefix
    #   - remove `0.` prefix induced by the sequential wrapper
    #   - remove `backbone.` prefix induced by pretraining
    _MODEL_KEY_SUBSTITUTIONS = (
        ("module.", ""),
        ("0.layers", "layers"),
        ("backbone.", ""),
    )

    # Same idea for an encoder checkpoint stored under the 'teacher' entry:
    #   - remove `module.` prefix
    #   - remove `backbone.` prefix induced by multicrop wrapper
    _TEACHER_KEY_SUBSTITUTIONS = (
        ("module.", ""),
        ("backbone.", ""),
    )

    def __init__(
        self,
        num_pools: Optional[list[int]] = None,
        num_classes: int = 1,
        factor: int = 32,
        encoder_ckpt: Optional[str] = None,
        model_ckpt: Optional[str] = None,
        use_deep: bool = True,
        in_planes: int = 1,
        flip_strides: bool = False,
        roll_strides: bool = True,  # used for models trained before commit f2ac9ee (August 2023)
    ):
        """
        Unet initialization.

        Parameters
        ----------
        num_pools : list of int, optional
            A list of integers defining the number of pooling layers for each
            dimension of the input. When omitted (None), ``[5, 5, 5]`` is used.
        num_classes : int, default=1
            The number of classes for segmentation.
        factor : int, default=32
            The scaling factor for the number of channels in VGG blocks.
        encoder_ckpt : str, optional
            Path to a checkpoint file from which to load encoder weights.
        model_ckpt : str, optional
            Path to a checkpoint file from which to load the entire model's weights.
        use_deep : bool, default=True
            Flag to indicate whether to use a deep decoder.
        in_planes : int, default=1
            The number of input channels.
        flip_strides : bool, default=False
            Flag to flip strides to match encoder and decoder dimensions.
            Useful for ensuring dimensionality alignment.
        roll_strides : bool, default=True
            Whether to roll strides when computing pooling (used for backward
            compatibility for models trained before commit f2ac9ee (August 2023)).
        """
        super().__init__()

        # Build a fresh list per instance instead of sharing one mutable
        # default object between every UNet created without `num_pools`.
        if num_pools is None:
            num_pools = [5, 5, 5]

        self.encoder = VGGEncoder(
            EncoderBlock,
            num_pools=num_pools,
            factor=factor,
            in_planes=in_planes,
            flip_strides=flip_strides,
            roll_strides=roll_strides,
        )
        self.decoder = VGGDecoder(
            EncoderBlock,
            num_pools=num_pools,
            num_classes=num_classes,
            factor_e=factor,
            factor_d=factor,
            use_deep=use_deep,
            flip_strides=flip_strides,
            roll_strides=roll_strides,
        )

        # load encoder if needed
        if encoder_ckpt is not None:
            self._load_encoder(encoder_ckpt)

        # The full-model checkpoint is loaded last, after the encoder one.
        if model_ckpt is not None:
            self.load(model_ckpt)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the CUDA device if available, else MPS if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # `torch.backends.mps` does not exist in older PyTorch releases, so
        # look it up defensively instead of assuming the attribute is there.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    @staticmethod
    def _rename_keys(state_dict: dict, substitutions) -> dict:
        """
        Return a copy of `state_dict` whose keys went through `substitutions`.

        Each ``(old, new)`` pair is applied in order with ``str.replace``, so
        every occurrence of ``old`` in a key is substituted, not only a
        leading one. Values are left untouched.
        """
        renamed = {}
        for key, value in state_dict.items():
            for old, new in substitutions:
                key = key.replace(old, new)
            renamed[key] = value
        return renamed

    def _load_encoder(self, encoder_ckpt: str) -> None:
        """
        Load encoder weights from a checkpoint holding a 'model' or 'teacher' entry.

        The result of ``load_state_dict`` is printed. If the checkpoint has
        neither entry, a warning is printed and the encoder is left as is.
        """
        print("Load encoder weights from", encoder_ckpt)
        device = self._select_device()
        self.encoder.to(device)

        # NOTE: torch.load relies on pickle; only load checkpoints from trusted sources.
        # `map_location` lets a checkpoint saved on a GPU be opened on a
        # machine without one (consistent with `load`).
        ckpt = torch.load(encoder_ckpt, map_location=device)

        if 'model' in ckpt:
            state_dict = self._rename_keys(ckpt['model'], self._MODEL_KEY_SUBSTITUTIONS)
        elif 'teacher' in ckpt:
            state_dict = self._rename_keys(ckpt['teacher'], self._TEACHER_KEY_SUBSTITUTIONS)
        else:
            print("[Warning] the following encoder couldn't be loaded, wrong key:", encoder_ckpt)
            return

        # strict=False: missing and unexpected keys are reported, not fatal.
        print(self.encoder.load_state_dict(state_dict, strict=False))

    def freeze_encoder(self, freeze: bool = True) -> None:
        """
        Freeze or unfreeze the encoder's weights based on the input flag.

        Parameters
        ----------
        freeze : bool, optional
            If True, the encoder's weights are frozen, otherwise they are unfrozen. Default is True.

        Returns
        -------
        None
        """
        if freeze:
            print("Freezing encoder weights...")
        else:
            print("Unfreezing encoder weights...")

        trainable = not freeze
        for param in self.encoder.parameters():
            param.requires_grad = trainable

    def unfreeze_encoder(self) -> None:
        """Unfreeze the encoder's weights. Convenience method calling `freeze_encoder` with `False`."""
        self.freeze_encoder(False)

    def load(self, model_ckpt: str) -> None:
        """
        Load the model from checkpoint.

        The checkpoint dictionary must have a 'model' key with the saved model for value.
        The model is moved to the selected device (CUDA, else MPS, else CPU)
        before loading, and the result of ``load_state_dict`` is printed.

        Parameters
        ----------
        model_ckpt : str
            The path to the checkpoint file containing the model's weights.

        Raises
        ------
        KeyError
            If the checkpoint has no 'model' entry.

        Returns
        -------
        None
        """
        print("Load model weights from", model_ckpt)
        device = self._select_device()
        self.to(device)

        # NOTE: torch.load relies on pickle; only load checkpoints from trusted sources.
        ckpt = torch.load(model_ckpt, map_location=device)
        if 'model' not in ckpt:
            raise KeyError(
                "checkpoint {} has no 'model' entry".format(model_ckpt)
            )

        state_dict = ckpt['model']
        # This entry is discarded before loading; the reason is not visible
        # here (presumably a pretraining-only layer -- TODO confirm).
        state_dict.pop('encoder.last_layer.weight', None)

        # strict=False: missing and unexpected keys are reported, not fatal.
        print(self.load_state_dict(state_dict, strict=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Define the forward pass of the UNet model.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor representing the image to be segmented.

        Returns
        -------
        torch.Tensor
            The output segmentation map tensor.
        """
        # The encoder output is handed to the decoder unchanged.
        return self.decoder(self.encoder(x))
#---------------------------------------------------------------------------