VGG3D Decoder
This decoder is based on the nnU-Net decoder and is currently used by EffUNet and UNet.
It is a 3D VGG decoder with deep supervision: each decoder level produces its own output.
- class biom3d.models.decoder_vgg_deep.DecoderBlock(*args: Any, **kwargs: Any)
A decoder block consisting of an upsampling operation followed by an encoder block.
This block upsamples the lower-resolution feature map, concatenates it with a skip connection from the encoder, and processes the result through a residual encoder block.
- Variables:
expansion (int) – Expansion factor of the encoder block (default is 1).
up (nn.ConvTranspose3d) – Transposed convolution used for upsampling.
encoder_block (nn.Module) – Encoder block applied after concatenation.
- __init__(block: type[torch.nn.Module], in_planes_low: int, in_planes_high: int, planes: int, stride: list[int])
Initialize a decoder block consisting of an upsampling operation followed by an encoder block.
- Parameters:
block (type[nn.Module]) – Class of the residual block used to build encoder/decoder layers (e.g., EncoderBlock).
in_planes_low (int) – Number of input channels from the low-resolution feature map.
in_planes_high (int) – Number of channels from the high-resolution skip connection.
planes (int) – Number of output channels after the encoder block.
stride (list of int) – Stride for the transposed convolution (upsampling factor).
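The upsample, concatenate, then convolve pattern described for DecoderBlock can be sketched as follows. This is a minimal illustration and not biom3d's implementation: `SketchConvBlock` and `SketchDecoderBlock` are hypothetical names, the plain convolution stands in for whatever class is passed as `block`, and the choice of `kernel_size == stride` for the transposed convolution is an assumption made so that each spatial dimension is multiplied exactly by the stride.

```python
import torch
import torch.nn as nn


class SketchConvBlock(nn.Module):
    """Hypothetical stand-in for the `block` argument (not biom3d's EncoderBlock)."""

    def __init__(self, in_planes: int, planes: int):
        super().__init__()
        self.conv = nn.Conv3d(in_planes, planes, kernel_size=3, padding=1)
        self.act = nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.conv(x))


class SketchDecoderBlock(nn.Module):
    """Illustrative decoder block: upsample, concatenate with the skip, convolve."""

    def __init__(self, in_planes_low: int, in_planes_high: int, planes: int, stride: list):
        super().__init__()
        # Assumption: kernel_size == stride, so output size = input size * stride.
        self.up = nn.ConvTranspose3d(
            in_planes_low, in_planes_high,
            kernel_size=tuple(stride), stride=tuple(stride), bias=False,
        )
        # After concatenation the channel count is 2 * in_planes_high.
        self.encoder_block = SketchConvBlock(2 * in_planes_high, planes)

    def forward(self, x_low: torch.Tensor, x_high: torch.Tensor) -> torch.Tensor:
        x_low = self.up(x_low)                    # match the skip's resolution
        x = torch.cat([x_low, x_high], dim=1)     # fuse along the channel axis
        return self.encoder_block(x)


# Usage with made-up shapes: (batch, channels, depth, height, width).
low = torch.randn(1, 64, 4, 8, 8)       # low-resolution feature map
high = torch.randn(1, 32, 8, 16, 16)    # high-resolution skip connection
block = SketchDecoderBlock(in_planes_low=64, in_planes_high=32, planes=32, stride=[2, 2, 2])
out_shape = tuple(block(low, high).shape)
```

In this sketch the output keeps the spatial size of the skip connection and has `planes` channels.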
- class biom3d.models.decoder_vgg_deep.VGGDecoder(*args: Any, **kwargs: Any)
A VGG-style decoder with optional deep supervision and intermediate embeddings.
This decoder reconstructs feature maps by progressively upsampling and fusing skip connections from an encoder. It supports multi-scale supervision and embedding output.
- Variables:
use_deep (bool) – If True, enable deep supervision (multi-level outputs).
use_emb (bool) – If True, only return the intermediate embedding from the third decoder stage.
strides (list[list[int]]) – List of strides (upsampling factors) per decoder stage.
layers (nn.ModuleList) – List of DecoderBlocks composing the decoder.
convs (nn.ModuleList) – List of 1×1 convolutions applied after each decoder stage (for supervision).
- __init__(block: type[torch.nn.Module], num_pools: list[int], factor_e: int | list[int] = 32, factor_d: int | list[int] = 32, flip_strides: bool = False, num_classes: int = 1, use_deep: bool = True, use_emb: bool = False, roll_strides: bool = True)
Initialize the decoder architecture.
- Parameters:
block (type[nn.Module]) – Class of the residual block used to build encoder/decoder layers (e.g., EncoderBlock).
num_pools (list of int) – Number of pooling operations at each encoder stage.
factor_e (int or list of int, default=32) – Base or per-layer depth factor for encoder feature maps.
factor_d (int or list of int, default=32) – Base or per-layer depth factor for decoder feature maps.
flip_strides (bool, default=False) – Whether to reverse the order of the upsampling strides. Flipped strides create larger feature maps.
num_classes (int, default=1) – Number of output channels (e.g. segmentation classes).
use_deep (bool, default=True) – If True, enables deep supervision at multiple decoder levels.
use_emb (bool, default=False) – If True, return the third decoder output as an embedding.
roll_strides (bool, default=True) – Legacy support for reversing the encoder stride order, for older models that predate commit f2ac9ee (August 2023).
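To see why the order of the upsampling strides matters, the arithmetic can be sketched in plain Python. The helper `upsampled_shapes`, the stride values, and the starting shape are all hypothetical and are not taken from biom3d. Reversing the order never changes the final spatial size, since the per-axis factors are simply multiplied, but it does change the sizes of the intermediate feature maps.

```python
def upsampled_shapes(start, strides):
    """Return the spatial shape after each upsampling stage."""
    shapes = []
    current = list(start)
    for stride in strides:
        # Each axis is multiplied by the corresponding stride factor.
        current = [size * factor for size, factor in zip(current, stride)]
        shapes.append(tuple(current))
    return shapes


# Hypothetical anisotropic strides: the first stage does not upsample depth.
strides = [[1, 2, 2], [2, 2, 2], [2, 2, 2]]
start = (4, 4, 4)

default_order = upsampled_shapes(start, strides)
flipped_order = upsampled_shapes(start, strides[::-1])

print(default_order)  # [(4, 8, 8), (8, 16, 16), (16, 32, 32)]
print(flipped_order)  # [(8, 8, 8), (16, 16, 16), (16, 32, 32)]
```

With these assumed values the flipped order yields larger intermediate maps, while both orders end at the same resolution.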
- forward(x: list[torch.Tensor])
Forward pass through the decoder.
- Parameters:
x (list of torch.Tensor) – List of feature maps from the encoder, one per encoder stage. The length depends on the number of stages and is expected to be 6.
- Returns:
out – Final prediction map or list of maps (if deep supervision is enabled). If use_emb is True, returns only the intermediate embedding tensor.
- Return type:
torch.Tensor or list of torch.Tensor
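Because the return value is either a single tensor or a list of tensors, calling code usually has to handle both cases. The sketch below illustrates the idea of deep supervision with one 1×1×1 convolution head per decoder stage. It is not the VGGDecoder code: the names `stage_features`, `heads`, `sketch_outputs`, and `as_list` are hypothetical, the channel counts and shapes are invented, and the coarse-to-fine ordering of the returned list is an assumption that should be checked against the actual implementation.

```python
import torch
import torch.nn as nn

num_classes = 2

# Hypothetical decoder-stage outputs, coarse to fine: (batch, channels, D, H, W).
stage_features = [
    torch.randn(1, 64, 4, 4, 4),
    torch.randn(1, 32, 8, 8, 8),
    torch.randn(1, 16, 16, 16, 16),
]

# One 1x1x1 convolution per stage maps its features to `num_classes` channels.
heads = nn.ModuleList(
    [nn.Conv3d(f.shape[1], num_classes, kernel_size=1) for f in stage_features]
)


def sketch_outputs(features, use_deep: bool):
    """Return every per-stage map if use_deep is True, else only the finest one."""
    maps = [head(f) for head, f in zip(heads, features)]
    return maps if use_deep else maps[-1]


def as_list(out):
    """Normalize a tensor-or-list return value to a list of tensors."""
    return list(out) if isinstance(out, (list, tuple)) else [out]


deep_shapes = [tuple(o.shape) for o in as_list(sketch_outputs(stage_features, use_deep=True))]
single_shapes = [tuple(o.shape) for o in as_list(sketch_outputs(stage_features, use_deep=False))]
```

Normalizing to a list lets a loss function iterate over the outputs in the same way whether or not deep supervision is enabled.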