EffUNet

This encoder-decoder network uses an EfficientNet3D encoder and a VGGDecoder.

The 3D EfficientNet implementation is adapted from: https://github.com/shijianjian/EfficientNet-PyTorch-3D.

Usage:

model = EfficientNet3D.from_name("efficientnet-b1", override_params={'include_top': False}, in_channels=1)
model.cuda()        # on a CUDA machine
# model.to('mps')   # or, on Apple Silicon (use one device, not both)
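Only one of the device calls in the usage example applies on a given machine. A minimal sketch of choosing the device at runtime follows. The helper `pick_device` is hypothetical, and a stand-in `nn.Conv3d` replaces `EfficientNet3D` so the snippet runs without the encoder package.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple Silicon (MPS), then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in PyTorch >= 1.12, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Stand-in for EfficientNet3D.from_name(...), so the sketch is self-contained.
model = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3)
device = pick_device()
model = model.to(device)
```

The same `model.to(device)` call then works unchanged on CUDA, MPS, or CPU machines.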
class biom3d.models.unet3d_eff.EffUNet(*args: Any, **kwargs: Any)

3D U-Net model using EfficientNet3D as encoder and VGG-style decoder.

This model builds a pyramid of intermediate feature maps from the encoder and passes them to the decoder for semantic segmentation.

Variables:
  • encoder (EfficientNet3D) – EfficientNet3D encoder model.

  • pyramid (list) – List of intermediate encoder layers used for skip connections.

  • down (dict) – Dictionary mapping pyramid levels to encoder activations (populated via forward hooks).

  • decoder (torch.nn.Module) – VGG-style decoder module.

__init__(patch_size: int | tuple[int], num_pools: list[int] = [5, 5, 5], num_classes: int = 1, factor: int = 32, encoder_ckpt: str | None = None, model_ckpt: str | None = None, use_deep: bool = True, in_planes: int = 1)

Parameters:
  • patch_size (tuple of int or int) – Shape of the input patch (D, H, W). The encoder always uses a single int, but the config always passes a tuple, so both forms are accepted.

  • num_pools (list of int, default=[5, 5, 5]) – Number of pooling steps per spatial dimension.

  • num_classes (int, default=1) – Number of output segmentation classes.

  • factor (int, default=32) – Base scaling factor for the decoder channels.

  • encoder_ckpt (str or None, optional) – Path to a pretrained encoder checkpoint.

  • model_ckpt (str or None, optional) – Path to a full model checkpoint.

  • use_deep (bool, default=True) – Whether to use deep supervision in the decoder.

  • in_planes (int, default=1) – Number of input channels.

forward(x: torch.Tensor) -> torch.Tensor

Forward pass of the model.

Parameters:

x (torch.Tensor) – Input tensor of shape (N, C, D, H, W).

Returns:

Output segmentation map of shape (N, num_classes, D, H, W).

Return type:

torch.Tensor

freeze_encoder(freeze: bool = True) -> None

Freeze or unfreeze the encoder’s weights.

Parameters:

freeze (bool, optional) – If True (the default), disables gradient computation for encoder parameters; if False, re-enables it.

Return type:

None
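Freezing in PyTorch usually comes down to toggling `requires_grad` on each parameter. A minimal sketch of that idea, assuming `freeze_encoder` works this way; `set_requires_grad` and `ToyEncDec` are hypothetical stand-ins, not biom3d code.

```python
import torch.nn as nn


def set_requires_grad(module: nn.Module, freeze: bool = True) -> None:
    """Hypothetical helper: freeze (or unfreeze) every parameter of `module`."""
    for p in module.parameters():
        # Frozen parameters receive no gradient during backward().
        p.requires_grad = not freeze


class ToyEncDec(nn.Module):
    """Stand-in with the same encoder/decoder split as EffUNet."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.decoder = nn.Conv3d(4, 1, kernel_size=1)


net = ToyEncDec()
set_requires_grad(net.encoder, freeze=True)
trainable = [n for n, p in net.named_parameters() if p.requires_grad]
```

In this sketch, only the decoder parameters remain trainable; calling `set_requires_grad(net.encoder, freeze=False)` plays the role of `unfreeze_encoder()`.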

get_activation(name: str) -> Callable

Create a forward hook for capturing activations.

Parameters:

name (int or str) – Key (pyramid level) under which the captured activation is stored in down.

Returns:

A forward hook function.

Return type:

function
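A minimal sketch of the hook-factory pattern, assuming the hook stores each layer's output in a dict keyed by pyramid level (mirroring the `down` attribute). The module-level `down`, this `get_activation`, and the two-layer toy encoder are illustrative only.

```python
import torch
import torch.nn as nn

down = {}  # pyramid level -> captured activation


def get_activation(name):
    """Return a forward hook that stores the module's output under `name`."""
    def hook(module, inputs, output):
        down[name] = output  # returning None leaves the output unchanged
    return hook


# Toy two-stage encoder: each stride-2 conv halves the spatial size.
encoder = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, stride=2, padding=1),
    nn.Conv3d(8, 16, kernel_size=3, stride=2, padding=1),
)
for level, layer in enumerate(encoder):
    layer.register_forward_hook(get_activation(level))

x = torch.randn(2, 1, 32, 32, 32)
out = encoder(x)
```

Because the hooks fire on every forward pass, `down` always holds the activations of the most recent input, ready to be read as skip connections right after the encoder call.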

unfreeze_encoder() -> None

Shortcut for unfreezing the encoder.

biom3d.models.unet3d_eff.get_layer(model: torch.nn.Module, layer_names: list[str]) -> torch.nn.Module

Retrieve a submodule from a model based on a list of keys.

Parameters:
  • model (nn.Module) – The PyTorch model to search within.

  • layer_names (list of str) – List of submodule names to traverse, e.g. ['_blocks', '0', '_depthwise_conv'].

Returns:

The requested submodule.

Return type:

nn.Module
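The traversal can be sketched with `functools.reduce` over `getattr`. This is an assumption about how the lookup works, not biom3d's code, and the toy classes are hypothetical. PyTorch (1.9 and later) also provides `nn.Module.get_submodule`, which accepts the same path as a dot-separated string.

```python
from functools import reduce

import torch.nn as nn


def get_layer(model: nn.Module, layer_names: list[str]) -> nn.Module:
    """Walk down the module tree one attribute name at a time."""
    # nn.Module.__getattr__ resolves child names, including the "0", "1", ...
    # keys used by nn.ModuleList and nn.Sequential.
    return reduce(getattr, layer_names, model)


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self._depthwise_conv = nn.Conv3d(4, 4, 3, padding=1, groups=4)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._blocks = nn.ModuleList([ToyBlock(), ToyBlock()])


toy = ToyModel()
layer = get_layer(toy, ["_blocks", "0", "_depthwise_conv"])
```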

biom3d.models.unet3d_eff.get_outfmaps(layer: torch.nn.Module) -> int

Return the number of output feature maps (channels) of a layer.

Parameters:

layer (nn.Module) – The layer to inspect.

Returns:

Number of output feature maps (channels).

Return type:

int

Notes

Tries to read from ‘num_features’ or ‘in_channels’ attributes. Returns 0 on failure.
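A minimal sketch of the attribute lookup described in these notes. The order (`num_features` before `in_channels`) is an assumption, and this function only imitates the documented behavior.

```python
import torch.nn as nn


def get_outfmaps(layer: nn.Module) -> int:
    """Sketch: read the channel count from common layer attributes, else 0."""
    # Assumed lookup order: BatchNorm-style `num_features`, then `in_channels`.
    for attr in ("num_features", "in_channels"):
        value = getattr(layer, attr, None)  # None if the attribute is missing
        if isinstance(value, int):
            return value
    return 0  # neither attribute exists, e.g. an activation layer
```

Note that `in_channels` equals the output channel count only when a layer preserves its width, as the depthwise convolutions in EfficientNet blocks do. For a plain `nn.Conv3d(1, 32, 3)`, this lookup reports 1.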

biom3d.models.unet3d_eff.get_pyramid(model: torch.nn.Module, pyramid: dict) -> list[torch.nn.Module]

Retrieves multiple submodules from a model according to a dictionary of paths.

Parameters:
  • model (nn.Module) – The model to extract layers from.

  • pyramid (dict) – Dictionary where each value is a list of strings indicating a submodule path.

Examples

>>> pyramid = {
...     0: ['_conv_stem'],              # 100
...     1: ['_blocks', '1', '_bn0'],    # 50
...     2: ['_blocks', '3', '_bn0'],    # 25
...     3: ['_blocks', '5', '_bn0'],    # 12
...     4: ['_blocks', '11', '_bn0'],   # 6
...     5: ['_bn1']                     # 3
... }

Returns:

List of layers (submodules) extracted from the model.

Return type:

list of nn.Module
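Each value of `pyramid` is a submodule path of the kind `get_layer` accepts, so the lookup can be sketched as one traversal per entry. This assumes the returned list follows the dict's insertion order (levels 0, 1, 2, ...). The toy encoder only mimics EfficientNet3D's attribute names and is hypothetical.

```python
from functools import reduce

import torch.nn as nn


def get_pyramid(model: nn.Module, pyramid: dict) -> list[nn.Module]:
    """Sketch: resolve each submodule path in `pyramid`, in dict order."""
    return [reduce(getattr, path, model) for path in pyramid.values()]


class ToyBlock(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        self._bn0 = nn.BatchNorm3d(ch)


class ToyEncoder(nn.Module):
    """Stand-in exposing EfficientNet3D-like attribute names."""

    def __init__(self):
        super().__init__()
        self._conv_stem = nn.Conv3d(1, 8, 3, stride=2, padding=1)
        self._blocks = nn.ModuleList([ToyBlock(8), ToyBlock(16)])
        self._bn1 = nn.BatchNorm3d(32)


enc = ToyEncoder()
pyramid = {0: ["_conv_stem"], 1: ["_blocks", "1", "_bn0"], 2: ["_bn1"]}
layers = get_pyramid(enc, pyramid)
```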