"""3D VGG decoder, with deep supervision (each decoder level has an output)."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
from biom3d.utils import convert_num_pools
#---------------------------------------------------------------------------
# 3D VGG decoder
def _weights_init(m:nn.Module)->None:
"""
Initialize weights of convolutional and linear layers using Kaiming normal initialization.
Parameters
----------
m : nn.Module
A PyTorch module. If it's an instance of `nn.Conv3d` or `nn.Linear`, its weights will be initialized.
Returns
-------
None
"""
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
init.kaiming_normal_(m.weight)
# Single decoder stage: upsample, concatenate with the skip connection, then apply a block.
class DecoderBlock(nn.Module):
    """
    One decoder stage: upsample, merge with the skip connection, then apply a block.

    The low-resolution feature map is upsampled with a transposed convolution,
    concatenated along the channel axis with the high-resolution skip feature map,
    and the concatenation is fed to a single instance of `block`.

    :ivar int expansion: Class-level constant, set to 1.
    :ivar nn.ConvTranspose3d up: Transposed convolution used for upsampling.
    :ivar nn.Module encoder_block: Block applied to the concatenated feature maps.
    """

    expansion: int = 1
    up: nn.ConvTranspose3d
    encoder_block: nn.Module

    def __init__(
        self,
        block: type[nn.Module],
        in_planes_low: int,   # channels of the low-resolution input (before upsampling)
        in_planes_high: int,  # channels of the high-resolution skip connection
        planes: int,          # channels of the output
        stride: list[int],    # kernel size and stride of the transposed convolution
    ):
        """
        Build the upsampling layer and the block applied after concatenation.

        Parameters
        ----------
        block : type[nn.Module]
            Class instantiated as `block(in_planes=..., planes=..., stride=1)`.
        in_planes_low : int
            Number of channels of the low-resolution feature map.
        in_planes_high : int
            Number of channels of the high-resolution skip connection. The upsampled
            map is projected to this many channels as well, so the concatenation
            holds `2 * in_planes_high` channels.
        planes : int
            Number of output channels requested from `block`.
        stride : list of int
            Used both as kernel size and as stride of the transposed convolution.
        """
        super().__init__()

        # `up` is registered before `encoder_block`; keep this order so the
        # state_dict layout and random initialization sequence stay the same.
        upsampler = nn.ConvTranspose3d(
            in_planes_low,
            in_planes_high,
            kernel_size=stride,
            stride=stride,
            bias=False,
        )
        self.up = upsampler

        merged_channels = in_planes_high * 2
        self.encoder_block = block(in_planes=merged_channels, planes=planes, stride=1)

    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
        """
        Upsample the low-resolution map, fuse it with the skip map and apply the block.

        Parameters
        ----------
        x : list of torch.Tensor
            A pair `[low_res, high_res]` of feature maps to be merged.

        Returns
        -------
        torch.Tensor
            Output of `encoder_block` applied to the channel-wise concatenation of
            the upsampled `low_res` and `high_res`.
        """
        low_res, high_res = x
        upsampled = self.up(low_res)
        merged = torch.cat((upsampled, high_res), dim=1)
        return self.encoder_block(merged)
# Complete decoder assembled from DecoderBlock stages.
class VGGDecoder(nn.Module):
"""
A VGG-style decoder with optional deep supervision and intermediate embeddings.
This decoder reconstructs feature maps by progressively upsampling and fusing skip connections
from an encoder. It supports multi-scale supervision and embedding output.
:ivar bool use_deep: If True, enable deep supervision (multi-level outputs).
:ivar bool use_emb: If True, only return the intermediate embedding from the third decoder stage.
:ivar list[list[int]] strides: List of strides (upsampling factors) per decoder stage.
:ivar nn.ModuleList layers: List of DecoderBlocks composing the decoder.
:ivar nn.ModuleList convs: List of 1×1 convolutions applied after each decoder stage (for supervision).
"""
[docs]
def __init__(
self,
block:type[nn.Module],
num_pools:list[int],
factor_e:int|list[int] = 32,
factor_d:int|list[int] = 32,
flip_strides:bool = False,
num_classes:int = 1,
use_deep:bool=True,
use_emb:bool=False,
roll_strides:bool = True,
):
"""
Initialize the decoder architecture.
Parameters
----------
block : type[nn.Module]
Class of the residual block used to build encoder/decoder layers (e.g., EncoderBlock).
num_pools : list of int
Number of pooling operations at each encoder stage.
factor_e : int or list of int, default=32
Base or per-layer depth factor for encoder feature maps.
factor_d : int or list of int, default=32
Base or per-layer depth factor for decoder feature maps.
flip_strides : bool, default=False
Whether to reverse the order of upsampling strides. Flipped strides creates larger feature maps.
num_classes : int, default=1
Number of output channels (e.g. segmentation classes).
use_deep : bool, default=True
If True, enables deep supervision at multiple decoder levels.
use_emb : bool, default=False
If True, return the third decoder output as an embedding.
roll_strides : bool, default=True
Legacy support for reversing encoder stride order (for older models, before commit f2ac9ee (August 2023)).
"""
super(VGGDecoder, self).__init__()
self.use_deep = use_deep
self.use_emb = use_emb
# encoder pyramid planes/feature maps
max_num_pools = max(num_pools)+1
if isinstance(factor_e,int):
in_planes = [factor_e * i for i in [10,10,8,4,2,1]][-max_num_pools:]
elif isinstance(factor_e,list):
in_planes = factor_e[-max_num_pools:]
else:
print("[Error] factor_e has the wrong type {}".format(type(factor_e)))
in_planes_high = in_planes[1:]
# decoder planes/feature maps
if isinstance(factor_d,int):
planes = [factor_d * i for i in [10,8,4,2,1]][-max_num_pools+1:]
elif isinstance(factor_d,list):
planes = factor_d
else:
print("[Error] factor_d has the wrong type {}".format(type(factor_d)))
in_planes_low = [in_planes[0]]+planes[:-1]
# computes the strides for the scale factors
self.strides = convert_num_pools(num_pools=num_pools,roll_strides=roll_strides)
# if the encoder strides are flipped, the decoder strides are not
if not flip_strides: self.strides = np.flip(self.strides, axis=0).tolist()
# layer definition
self.layers = []
self.convs = []
for i in range(max_num_pools-1):
self.layers += [self._make_layer(
block,
in_planes_low=in_planes_low[i],
in_planes_high=in_planes_high[i],
planes=planes[i],
stride=self.strides[i],
num_blocks=1,
)]
self.convs += [nn.Conv3d(planes[i], num_classes, kernel_size=1, stride=1, padding=0, bias=False)]
# the lines below are required to register the module parameters
self.layers = nn.ModuleList(self.layers)
self.convs = nn.ModuleList(self.convs)
self.apply(_weights_init)
def _make_layer(self,
block:type[nn.Module],
in_planes_low:int,
in_planes_high:int,
planes:int,
stride:list[int],
num_blocks:int
)->nn.Sequential:
"""
Create a sequential layer composed of a DecoderBlock followed by encoder blocks.
This function builds a composite decoder stage. It first upsamples and fuses encoder features
using `DecoderBlock`, then adds more encoder blocks..
Parameters
----------
block : type[nn.Module]
Class of the encoder-style residual block used after the initial DecoderBlock.
in_planes_low : int
Number of input channels from the lower-resolution feature map.
in_planes_high : int
Number of input channels from the higher-resolution feature map (skip connection).
planes : int
Number of output channels after the decoder stage.
stride : list of int
Stride (upsampling factor) for the transposed convolution.
num_blocks : int
Number of residual blocks to apply in total (≥1). The first is wrapped in a DecoderBlock.
Returns
-------
nn.Sequential
A sequential module containing the DecoderBlock and additional residual blocks.
"""
layers = []
layers.append(DecoderBlock(block, in_planes_low, in_planes_high, planes, stride))
for _ in range(num_blocks-1):
layers.append(block(planes, planes, stride=1))
return nn.Sequential(*layers)
[docs]
def forward(self, x:list[torch.Tensor]):
"""
Forward pass through the decoder.
Parameters
----------
x : list of torch.Tensor
List of feature maps from the encoder (length depends on number of stages, but it should be 6).
Returns
-------
out: torch.Tensor or list of torch.Tensor
Final prediction map or list of maps (if deep supervision is enabled).
If `use_emb` is True, returns only the intermediate embedding tensor.
"""
deep_out = []
for i in range(len(self.layers)):
inputs = x[-1] if i==0 else out
out = self.layers[i]([inputs, x[-2-i]])
if i>=2 and self.use_deep: # deep supervision
tmp = self.convs[i](out)
deep_out += [F.interpolate(tmp, size=x[0].shape[2:], mode='trilinear')]
# if i is the antipenultimate layer
elif i==(len(self.layers)-3) and self.use_emb:
return self.convs[i](out)
out = self.convs[-1](out)
if self.use_deep: out = deep_out+[out]
return out
#---------------------------------------------------------------------------