VGG3dDeep
This is the encoder-decoder model that uses a VGGEncoder as its encoder and a VGGDecoder as its decoder. It is Biom3d's adaptation of the nnUnet base model.
- class biom3d.models.unet3d_vgg_deep.UNet(*args: Any, **kwargs: Any)
A 3D UNet that uses VGG-style encoder and decoder blocks for volumetric image segmentation.
UNet is a convolutional neural network architecture for fast and precise image segmentation. This implementation uses VGG blocks in the encoder, for deep feature extraction, and in the decoder, for reconstructing the segmentation map. The number of pooling layers per dimension and the number of output classes are configurable, the deep decoder is optional, and weights can be initialized from pre-trained checkpoints.
- Variables:
encoder (VGGEncoder) – The encoder part of the UNet, responsible for downscaling and feature extraction.
decoder (VGGDecoder) – The decoder part of the UNet, responsible for upscaling and constructing the segmentation map.
- __init__(num_pools: list[int] = [5, 5, 5], num_classes: int = 1, factor: int = 32, encoder_ckpt: str | None = None, model_ckpt: str | None = None, use_deep: bool = True, in_planes: int = 1, flip_strides: bool = False, roll_strides: bool = True)
Initialize the UNet.
- Parameters:
num_pools (list of int, default=[5,5,5]) – A list of integers defining the number of pooling layers for each dimension of the input.
num_classes (int, default=1) – The number of classes for segmentation.
factor (int, default=32) – The scaling factor for the number of channels in VGG blocks.
encoder_ckpt (str, optional) – Path to a checkpoint file from which to load encoder weights.
model_ckpt (str, optional) – Path to a checkpoint file from which to load the entire model’s weights.
use_deep (bool, default=True) – Flag to indicate whether to use a deep decoder.
in_planes (int, default=1) – The number of input channels.
flip_strides (bool, default=False) – Whether to flip the strides so that encoder and decoder dimensions stay aligned.
roll_strides (bool, default=True) – Whether to roll the strides when computing pooling. Kept for backward compatibility with models trained before commit f2ac9ee (August 2023).
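To illustrate what "number of pooling layers for each dimension" implies, the sketch below derives per-level strides from `num_pools`. It is a minimal sketch under stated assumptions, not Biom3d's code: the helpers `pool_strides` and `total_downsampling` are hypothetical, each pooling layer is assumed to halve its axis (stride 2), and an axis with fewer pools is assumed to be pooled at the first levels and kept at stride 1 afterwards. The actual ordering in Biom3d may differ, which is what `flip_strides` and `roll_strides` appear to control.

```python
from typing import List, Tuple


def pool_strides(num_pools: List[int]) -> List[Tuple[int, ...]]:
    """Hypothetical helper: per-level pooling strides derived from num_pools.

    Assumption: every pooling layer halves its axis (stride 2), and an axis
    with fewer pools than the deepest axis is pooled at the first levels and
    uses stride 1 afterwards. Biom3d's ordering (flip_strides/roll_strides)
    may differ.
    """
    n_levels = max(num_pools)
    return [
        tuple(2 if level < n else 1 for n in num_pools)
        for level in range(n_levels)
    ]


def total_downsampling(num_pools: List[int]) -> Tuple[int, ...]:
    """Overall downsampling factor per axis: the product of its strides."""
    factors = [1] * len(num_pools)
    for strides in pool_strides(num_pools):
        for axis, s in enumerate(strides):
            factors[axis] *= s
    return tuple(factors)


# Anisotropic example: only 3 pools along the first axis.
print(pool_strides([3, 5, 5]))
# [(2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2)]
print(total_downsampling([3, 5, 5]))
# (8, 32, 32)
```

With the default `num_pools=[5, 5, 5]`, every axis is downsampled by a factor of 2**5 = 32 under this assumption.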
- forward(x: torch.Tensor) → torch.Tensor
Define the forward pass of the UNet model.
- Parameters:
x (torch.Tensor) – The input tensor representing the image to be segmented.
- Returns:
The output segmentation map tensor.
- Return type:
torch.Tensor
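Because the encoder downsamples each axis several times, the spatial size of `x` usually has to be compatible with the pooling depth. The sketch below computes the smallest padded shape that is divisible by 2**num_pools[d] on every axis. It is a minimal sketch under the assumption that each of the `num_pools[d]` pooling layers halves axis `d`. The helper `padded_spatial_shape` is hypothetical and not part of Biom3d.

```python
import math
from typing import List, Tuple


def padded_spatial_shape(
    spatial_shape: Tuple[int, ...], num_pools: List[int]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: smallest shape >= spatial_shape that is divisible
    by 2**num_pools on every axis, plus the padding needed per axis.

    Assumption: each of the num_pools[d] pooling layers halves axis d.
    """
    if len(spatial_shape) != len(num_pools):
        raise ValueError("spatial_shape and num_pools must have the same length")
    target = tuple(
        math.ceil(size / 2**n) * 2**n for size, n in zip(spatial_shape, num_pools)
    )
    padding = tuple(t - s for t, s in zip(target, spatial_shape))
    return target, padding


# A 100 x 150 x 150 patch with the default num_pools=[5, 5, 5].
print(padded_spatial_shape((100, 150, 150), [5, 5, 5]))
# ((128, 160, 160), (28, 10, 10))
```

Padding the volume to that shape (for example with `torch.nn.functional.pad`) before calling `forward` is one option; whether Biom3d's own preprocessing already guarantees compatible patch sizes is not covered by this section.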
- freeze_encoder(freeze: bool = True) → None
Freeze or unfreeze the encoder’s weights based on the input flag.
- Parameters:
freeze (bool, optional) – If True, the encoder’s weights are frozen, otherwise they are unfrozen. Default is True.
- Return type:
None
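In PyTorch, freezing a sub-module usually means turning off gradient tracking for its parameters. The sketch below shows that pattern on a hypothetical stand-in model, `TinyEncoderDecoder`, which only mirrors the `encoder` and `decoder` attributes documented above. It assumes, without confirmation from this section, that `freeze_encoder` toggles `requires_grad` on the encoder's parameters. It is not Biom3d's implementation.

```python
import torch.nn as nn


class TinyEncoderDecoder(nn.Module):
    """Hypothetical stand-in with `encoder` and `decoder` attributes,
    mirroring the two attributes documented for UNet (not Biom3d code)."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.decoder = nn.Conv3d(4, 1, kernel_size=3, padding=1)

    def freeze_encoder(self, freeze: bool = True) -> None:
        # Assumed behavior: requires_grad=False stops autograd from computing
        # gradients for the encoder, so only the decoder gets trained.
        self.encoder.requires_grad_(not freeze)


model = TinyEncoderDecoder()
model.freeze_encoder()  # e.g. fine-tune only the decoder on a pre-trained encoder
print(all(not p.requires_grad for p in model.encoder.parameters()))  # True
print(all(p.requires_grad for p in model.decoder.parameters()))  # True
model.freeze_encoder(False)  # unfreeze the encoder again
```

A common companion pattern is to pass only trainable parameters to the optimizer, for example `(p for p in model.parameters() if p.requires_grad)`.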