# Source code for biom3d.models.unet3d_eff

"""
3D EfficientNet adapted from: https://github.com/shijianjian/EfficientNet-PyTorch-3D.

Usage:

.. code-block:: python

    model = EfficientNet3D.from_name("efficientnet-b1", override_params={'include_top': False}, in_channels=1)
    model.cuda() # On CUDA machine
    model.to('mps') # On Apple Silicon

"""

from typing import Callable, Optional
from biom3d.models.encoder_vgg import EncoderBlock
from biom3d.models.decoder_vgg_deep import VGGDecoder
from biom3d.models.encoder_efficientnet3d import EfficientNet3D, efficientnet3d

import torch
from torch import nn

def get_layer(model: nn.Module, layer_names: list[str]) -> nn.Module:
    """
    Walk down a module hierarchy following a path of child names.

    Parameters
    ----------
    model : nn.Module
        Root module where the walk starts.
    layer_names : list of str
        Successive child names, e.g. ['_blocks', '0', '_depthwise_conv'].
        An empty list returns ``model`` itself.

    Returns
    -------
    nn.Module
        The submodule reached at the end of the path.

    Raises
    ------
    KeyError
        If one of the names is not a direct child of the current module.
    """
    current = model
    for child_name in layer_names:
        # `_modules` is the ordered dict of direct children of an nn.Module.
        current = current._modules[child_name]
    return current
def get_pyramid(model: nn.Module, pyramid: dict) -> list[nn.Module]:
    """
    Collect several submodules of a model, one per entry of a path dictionary.

    Parameters
    ----------
    model : nn.Module
        The model to extract layers from.
    pyramid : dict
        Mapping whose values are submodule paths (lists of child names), as
        accepted by ``get_layer``. Paths are resolved in the dictionary's
        iteration order; keys are not used.

    Examples
    --------
    >>> pyramid = {
    ...     0: ['_conv_stem'],              # 100
    ...     1: ['_blocks', '1', '_bn0'],    # 50
    ...     2: ['_blocks', '3', '_bn0'],    # 25
    ...     3: ['_blocks', '5', '_bn0'],    # 12
    ...     4: ['_blocks', '11', '_bn0'],   # 6
    ...     5: ['_bn1']                     # 3
    ... }

    Returns
    -------
    list of nn.Module
        The resolved submodules, in the same order as ``pyramid.values()``.
    """
    return [get_layer(model, path) for path in pyramid.values()]
def get_outfmaps(layer: nn.Module) -> int:
    """
    Returns the depth of output feature maps of a layer.

    Parameters
    ----------
    layer : nn.Module
        The layer to inspect.

    Returns
    -------
    int
        Number of output feature maps (channels), or 0 if it cannot be found.

    Notes
    -----
    Looks, in order, for the instance attributes ``num_features`` (set by
    normalization layers such as BatchNorm), ``out_channels`` (set by
    convolution layers) and finally ``in_channels``. The last one is kept only
    as a fallback for layers that expose no output depth. If none is present,
    prints an error and returns 0.
    """
    instance_attributes = vars(layer)
    # out_channels must come before in_channels: a convolution's output depth,
    # which is what the forward hooks capture, is out_channels.
    for attribute in ("num_features", "out_channels", "in_channels"):
        if attribute in instance_attributes:
            return getattr(layer, attribute)
    print("[Error] layer is not standard, cannot extract output feature maps.")
    return 0
#--------------------------------------------------------------------------- # 3D UNet with the previous encoder and decoder
[docs] class EffUNet(nn.Module): """ 3D U-Net model using EfficientNet3D as encoder and VGG-style decoder. This model builds a pyramid of intermediate feature maps from the encoder, and passes them to the decoder for semantic segmentation. :ivar EfficientNet3D encoder: EfficientNet3D encoder model. :ivar list pyramid: List of intermediate encoder layers used for skip connections. :ivar dict down: Dictionary mapping pyramid levels to encoder activations (populated via forward hooks). :ivar torch.nn.Module decoder: VGG-style decoder module. """
[docs] def __init__( self, patch_size:int|tuple[int], # TODO: Clement: Guillaume this should be a tuple (or something like it) but the whole code of the encoder is considering it as an int, we need to make it clear num_pools:list[int]=[5,5,5], num_classes:int=1, factor:int=32, encoder_ckpt:Optional[str] = None, model_ckpt:Optional[str] = None, use_deep:bool=True, in_planes:int = 1, ): """ 3D U-Net model using EfficientNet3D as encoder and VGG-style decoder. This model builds a pyramid of intermediate feature maps from the encoder, and passes them to the decoder for semantic segmentation. Parameters ---------- patch_size : tuple of int or int Shape of the input patch (D, H, W). The encoder will alwayse use an int but the config will always send a tuple ¯\\_(ツ)_/¯. num_pools : list of int, default=[5, 5, 5] Number of pooling steps per spatial dimension. num_classes : int, default=1 Number of output segmentation classes. factor : int, default=32 Base scaling factor for the decoder channels. encoder_ckpt : str or None, optional Path to a pretrained encoder checkpoint. model_ckpt : str or None, optional Path to a full model checkpoint. use_deep : bool, default=True Whether to use deep supervision in the decoder. in_planes : int, default=1 Number of input channels. 
""" super(EffUNet, self).__init__() pyramid={ # efficientnet b4 0: ['_bn0'], 1: ['_blocks', '1', '_bn2'], 2: ['_blocks', '5', '_bn2'], 3: ['_blocks', '9', '_bn2'], 4: ['_blocks', '21', '_bn2'], 5: ['_blocks', '31', '_bn2'], } blocks_args, global_params = efficientnet3d( width_coefficient=1.4, # efficientnet b4 depth_coefficient=1.8, dropout_rate=0.4, drop_connect_rate=0.2, image_size=patch_size, include_top=False ) self.encoder = EfficientNet3D( blocks_args, global_params, in_channels=in_planes, num_pools=num_pools, ) # load encoder if needed if encoder_ckpt is not None: print("Load encoder weights from", encoder_ckpt) if torch.cuda.is_available(): self.encoder.cuda() elif torch.backends.mps.is_available(): self.encoder.to('mps') ckpt = torch.load(encoder_ckpt) if 'model' in ckpt.keys(): # remove `module.` prefix state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()} # remove `0.` prefix induced by the sequential wrapper state_dict = {k.replace("0.layers", "layers"): v for k, v in state_dict.items()} print(self.encoder.load_state_dict(state_dict, strict=False)) elif 'teacher' in ckpt.keys(): # remove `module.` prefix state_dict = {k.replace("module.", ""): v for k, v in ckpt['teacher'].items()} # remove `backbone.` prefix induced by multicrop wrapper state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} print(self.encoder.load_state_dict(state_dict, strict=False)) else: print("[Warning] the following encoder couldn't be loaded, wrong key:", encoder_ckpt) self.pyramid = get_pyramid(self.encoder, pyramid) # only the first five elements of the list are used # hook the pyramid self.down = {} for i in range(len(self.pyramid)): self.pyramid[i].register_forward_hook(self.get_activation(i)) self.decoder = VGGDecoder( EncoderBlock, # SmallEncoderBlock, num_pools=num_pools, num_classes=num_classes, factor_e=[get_outfmaps(l) for l in self.pyramid][::-1], factor_d=[get_outfmaps(l)*2 for l in self.pyramid][::-1][1:-1]+[factor], 
use_deep=use_deep, ) if model_ckpt is not None: print("Load model weights from", model_ckpt) if torch.cuda.is_available(): self.cuda() elif torch.backends.mps.is_available(): self.to('mps') ckpt = torch.load(model_ckpt) if 'encoder.last_layer.weight' in ckpt['model'].keys(): del ckpt['model']['encoder.last_layer.weight'] self.load_state_dict(ckpt['model'])
[docs] def freeze_encoder(self, freeze:bool=True)->None: """ Freeze or unfreeze the encoder's weights. Parameters ---------- freeze : bool, optional If True, disables gradient computation for encoder parameters. Returns ------- None """ if freeze: print("Freezing encoder weights...") else: print("Unfreezing encoder weights...") for l in self.encoder.parameters(): l.requires_grad = not freeze
[docs] def unfreeze_encoder(self)->None: """Shortcut for unfreezing the encoder.""" self.freeze_encoder(False)
[docs] def get_activation(self, name:str)->Callable: """ Create a forward hook for capturing activations. Parameters ---------- name : int Index of the pyramid level to assign the activation to. Returns ------- function A forward hook function. """ def hook(model, input, output): self.down[name] = output return hook
[docs] def forward(self, x:torch.Tensor)->torch.Tensor: """ Forward pass of the model. Parameters ---------- x : torch.Tensor Input tensor of shape (N, C, D, H, W). Returns ------- torch.Tensor Output segmentation map of shape (N, num_classes, D, H, W). """ self.encoder(x) out = self.decoder(list(self.down.values())) return out
#---------------------------------------------------------------------------