VGG3D Decoder

This decoder is based on the nnUnet decoder and is currently used by EffUNet and UNet.

3D VGG decoder with deep supervision (each decoder level produces an output).

class biom3d.models.decoder_vgg_deep.DecoderBlock(*args: Any, **kwargs: Any)

A decoder block consisting of an upsampling operation followed by an encoder block.

This block upsamples the lower-resolution feature map, concatenates it with a skip connection from the encoder, and processes the result through a residual encoder block.
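The data flow described above (upsample, concatenate the skip connection, refine) can be sketched in a few lines of PyTorch. This is not the biom3d implementation. SimpleEncoderBlock and SketchDecoderBlock are hypothetical names. The channel convention is an assumption borrowed from nnUnet-style decoders: the transposed convolution maps in_planes_low to in_planes_high channels, so the block applied after concatenation receives 2 * in_planes_high channels. The sketch also assumes kernel_size == stride for the upsampling layer.

```python
import torch
import torch.nn as nn


class SimpleEncoderBlock(nn.Module):
    """Hypothetical stand-in for `block`: two 3x3x3 convs with InstanceNorm + LeakyReLU."""

    def __init__(self, in_planes: int, planes: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv3d(in_planes, planes, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(planes, affine=True),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(planes, planes, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(planes, affine=True),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class SketchDecoderBlock(nn.Module):
    """Hypothetical decoder block: upsample low-res input, concat skip, refine."""

    def __init__(self, block, in_planes_low, in_planes_high, planes, stride):
        super().__init__()
        # Assumption: kernel_size == stride, so each spatial axis is scaled by its stride.
        self.up = nn.ConvTranspose3d(in_planes_low, in_planes_high,
                                     kernel_size=tuple(stride), stride=tuple(stride),
                                     bias=False)
        # Assumption: after concatenation the channel count is 2 * in_planes_high.
        self.encoder_block = block(2 * in_planes_high, planes)

    def forward(self, x):
        low, high = x                           # pair [low_res, high_res]
        up = self.up(low)                       # now matches high_res spatially
        merged = torch.cat([up, high], dim=1)   # concatenate along the channel axis
        return self.encoder_block(merged)


block = SketchDecoderBlock(SimpleEncoderBlock, in_planes_low=64, in_planes_high=32,
                           planes=32, stride=[1, 2, 2])
low = torch.randn(1, 64, 8, 16, 16)
high = torch.randn(1, 32, 8, 32, 32)
out = block([low, high])
print(tuple(out.shape))  # (1, 32, 8, 32, 32)
```

Concatenating rather than adding the skip connection keeps the encoder features intact and lets the following block learn how to mix them, which is the usual U-Net design choice.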

Variables:
  • expansion (int) – Expansion factor of the encoder block (default is 1).

  • up (nn.ConvTranspose3d) – Transposed convolution used for upsampling.

  • encoder_block (nn.Module) – Encoder block applied after concatenation.

__init__(block: type[torch.nn.Module], in_planes_low: int, in_planes_high: int, planes: int, stride: list[int])

Initialize a decoder block consisting of an upsampling operation followed by an encoder block.

Parameters:
  • block (type[nn.Module]) – Class of the residual block used to build encoder/decoder layers (e.g., EncoderBlock).

  • in_planes_low (int) – Number of input channels from the low-resolution feature map.

  • in_planes_high (int) – Number of channels from the high-resolution skip connection.

  • planes (int) – Number of output channels after the encoder block.

  • stride (list of int) – Stride for the transposed convolution (upsampling factor).
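The stride parameter acts as a per-axis upsampling factor. Assuming the transposed convolution uses kernel_size equal to stride with no padding (an assumption, since the kernel size is not documented here), each spatial dimension is simply multiplied by its stride. The snippet below checks this with plain torch.nn.ConvTranspose3d; upsampled_shape is a hypothetical helper, and an anisotropic stride such as (1, 2, 2) leaves the depth axis unchanged.

```python
import torch
import torch.nn as nn


def upsampled_shape(spatial, stride):
    """Expected spatial size after ConvTranspose3d(kernel_size=stride, stride=stride)."""
    # Output = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1.
    # With kernel_size == stride, padding=0, dilation=1, output_padding=0 this is in * stride.
    return [s * st for s, st in zip(spatial, stride)]


stride = (1, 2, 2)  # hypothetical values: keep depth, double height and width
up = nn.ConvTranspose3d(64, 32, kernel_size=stride, stride=stride, bias=False)
x = torch.randn(1, 64, 10, 20, 20)
y = up(x)
print(tuple(y.shape))                         # (1, 32, 10, 40, 40)
print(upsampled_shape([10, 20, 20], stride))  # [10, 40, 40]
```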

forward(x: list[torch.Tensor]) → torch.Tensor

Forward pass of the decoder block.

Parameters:

x (list of torch.Tensor) – A pair [low_res, high_res] of feature maps to be merged.

Returns:

Output of the encoder block after upsampling and concatenation.

Return type:

torch.Tensor

class biom3d.models.decoder_vgg_deep.VGGDecoder(*args: Any, **kwargs: Any)

A VGG-style decoder with optional deep supervision and intermediate embeddings.

This decoder reconstructs feature maps by progressively upsampling and fusing skip connections from an encoder. It supports multi-scale supervision and embedding output.

Variables:
  • use_deep (bool) – If True, enable deep supervision (multi-level outputs).

  • use_emb (bool) – If True, only return the intermediate embedding from the third decoder stage.

  • strides (list[list[int]]) – List of strides (upsampling factors) per decoder stage.

  • layers (nn.ModuleList) – List of DecoderBlocks composing the decoder.

  • convs (nn.ModuleList) – List of 1×1 convolutions applied after each decoder stage (for supervision).
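To illustrate how layers and convs might interact, here is a simplified, hypothetical decoder (TinyDecoderBlock and TinyDeepDecoder are illustrative names, not biom3d classes). It walks the encoder features from the deepest level upward, fuses each skip connection, and, when use_deep is True, applies a 1×1×1 head after every stage. use_emb is not modelled, and the coarse-to-fine ordering of the returned list is a choice made for this sketch only; biom3d may order its outputs differently.

```python
import torch
import torch.nn as nn


class TinyDecoderBlock(nn.Module):
    """Hypothetical minimal decoder block: upsample, concat skip, one 3x3x3 conv."""

    def __init__(self, in_low, in_high, planes, stride):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_low, in_high, kernel_size=stride, stride=stride)
        self.conv = nn.Sequential(nn.Conv3d(2 * in_high, planes, 3, padding=1), nn.LeakyReLU())

    def forward(self, low, high):
        return self.conv(torch.cat([self.up(low), high], dim=1))


class TinyDeepDecoder(nn.Module):
    """Sketch of a decoder with one 1x1x1 output head per stage (deep supervision)."""

    def __init__(self, channels, strides, num_classes=1, use_deep=True):
        super().__init__()
        # channels: encoder widths from highest to lowest resolution, e.g. [8, 16, 32]
        # strides[i]: upsampling factor from level i+1 back to level i
        self.use_deep = use_deep
        self.layers = nn.ModuleList()
        self.convs = nn.ModuleList()
        for i in reversed(range(len(channels) - 1)):
            self.layers.append(TinyDecoderBlock(channels[i + 1], channels[i], channels[i], strides[i]))
            self.convs.append(nn.Conv3d(channels[i], num_classes, kernel_size=1))

    def forward(self, feats):
        # feats: encoder outputs ordered from highest to lowest resolution
        x = feats[-1]
        outs = []
        for j, (layer, head) in enumerate(zip(self.layers, self.convs)):
            x = layer(x, feats[-2 - j])      # fuse the matching skip connection
            if self.use_deep:
                outs.append(head(x))         # one prediction per decoder stage
        if self.use_deep:
            return outs                      # coarse-to-fine in this sketch
        return self.convs[-1](x)             # only the full-resolution prediction


feats = [torch.randn(1, 8, 16, 32, 32), torch.randn(1, 16, 8, 16, 16), torch.randn(1, 32, 4, 8, 8)]
dec = TinyDeepDecoder([8, 16, 32], strides=[(2, 2, 2), (2, 2, 2)], num_classes=3)
outs = dec(feats)
print([tuple(o.shape) for o in outs])  # [(1, 3, 8, 16, 16), (1, 3, 16, 32, 32)]
```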

__init__(block: type[torch.nn.Module], num_pools: list[int], factor_e: int | list[int] = 32, factor_d: int | list[int] = 32, flip_strides: bool = False, num_classes: int = 1, use_deep: bool = True, use_emb: bool = False, roll_strides: bool = True)

Initialize the decoder architecture.

Parameters:
  • block (type[nn.Module]) – Class of the residual block used to build encoder/decoder layers (e.g., EncoderBlock).

  • num_pools (list of int) – Number of pooling operations at each encoder stage.

  • factor_e (int or list of int, default=32) – Base or per-layer depth factor for encoder feature maps.

  • factor_d (int or list of int, default=32) – Base or per-layer depth factor for decoder feature maps.

  • flip_strides (bool, default=False) – Whether to reverse the order of upsampling strides. Flipped strides create larger feature maps.

  • num_classes (int, default=1) – Number of output channels (e.g. segmentation classes).

  • use_deep (bool, default=True) – If True, enables deep supervision at multiple decoder levels.

  • use_emb (bool, default=False) – If True, return the third decoder output as an embedding.

  • roll_strides (bool, default=True) – Legacy option that reverses the encoder stride order, kept for compatibility with older models created before commit f2ac9ee (August 2023).
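How num_pools, factor_d and flip_strides translate into per-stage strides and widths is not spelled out above. The helpers below are one hypothetical reading, not part of biom3d: they assume num_pools lists the number of poolings along each spatial axis (an nnUnet-style convention) and that an integer factor doubles at every level. The real decoder may compute or cap widths differently.

```python
def hypothetical_strides(num_pools, flip_strides=False):
    """Per-stage strides, assuming num_pools gives the pool count along each axis."""
    n_stages = max(num_pools)
    strides = []
    for stage in range(n_stages):
        # An axis is resampled by 2 while it still has pooling operations left, else 1.
        strides.append([2 if stage < p else 1 for p in num_pools])
    if flip_strides:
        strides = strides[::-1]  # reverse the stage order
    return strides


def hypothetical_widths(factor, n_levels):
    """Channel width per level: factor * 2**i for an int factor, else used as given."""
    if isinstance(factor, int):
        return [factor * 2 ** i for i in range(n_levels)]
    return list(factor)


print(hypothetical_strides([3, 5, 5]))
# [[2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2], [1, 2, 2]]
print(hypothetical_strides([3, 5, 5], flip_strides=True))
# [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
print(hypothetical_widths(32, 6))
# [32, 64, 128, 256, 512, 1024]
```

Under this reading, flipping moves the stages that skip the depth axis to the front, so the early, high-resolution feature maps keep their full depth for longer.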

forward(x: list[torch.Tensor])

Forward pass through the decoder.

Parameters:

x (list of torch.Tensor) – List of feature maps from the encoder. Its length depends on the number of encoder stages and is expected to be 6.

Returns:

out – Final prediction map, or a list of maps if deep supervision is enabled. If use_emb is True, only the intermediate embedding tensor is returned.

Return type:

torch.Tensor or list of torch.Tensor
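Because forward returns either a single tensor or a list of tensors, calling code usually needs to normalize the result. A small helper like the hypothetical final_prediction below does this without relying on list order: it assumes the final prediction is the map with the most voxels, and passes a plain tensor (including a use_emb embedding) through unchanged.

```python
import torch


def final_prediction(out):
    """Return one tensor from a decoder output that may be a tensor or a list."""
    if isinstance(out, torch.Tensor):
        return out  # deep supervision off, or use_emb embedding
    # Assumption: the final output is the highest-resolution map (most voxels),
    # so pick it by spatial size instead of by position in the list.
    return max(out, key=lambda t: t.shape[2:].numel())


outs = [torch.zeros(1, 2, 4, 8, 8), torch.zeros(1, 2, 16, 32, 32), torch.zeros(1, 2, 8, 16, 16)]
pred = final_prediction(outs)
print(tuple(pred.shape))  # (1, 2, 16, 32, 32)
```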