EffUNet
This is the encoder-decoder that uses an EfficientNet3D encoder and a VGGDecoder.
The 3D EfficientNet is taken from: https://github.com/shijianjian/EfficientNet-PyTorch-3D.
Usage:
model = EfficientNet3D.from_name("efficientnet-b1", override_params={'include_top': False}, in_channels=1)
# Then move the model to your device, using one of:
model.cuda()      # on a CUDA machine
model.to('mps')   # on Apple Silicon
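Rather than calling `.cuda()` or `.to('mps')` by hand, the device can be chosen at run time. The sketch below uses only standard PyTorch calls. `pick_device` is a hypothetical helper name, and a small `nn.Linear` stands in for the EfficientNet3D model, so the snippet runs without biom3d.

```python
import torch
from torch import nn


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple Silicon (MPS), else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in recent PyTorch builds, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Stand-in for the EfficientNet3D model built in the usage example.
model = nn.Linear(4, 2)
device = pick_device()
model = model.to(device)
```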
- class biom3d.models.unet3d_eff.EffUNet(*args: Any, **kwargs: Any)
3D U-Net model using EfficientNet3D as encoder and VGG-style decoder.
This model builds a pyramid of intermediate feature maps from the encoder and passes them to the decoder for semantic segmentation.
- Variables:
encoder (EfficientNet3D) – EfficientNet3D encoder model.
pyramid (list) – List of intermediate encoder layers used for skip connections.
down (dict) – Dictionary mapping pyramid levels to encoder activations (populated via forward hooks).
decoder (torch.nn.Module) – VGG-style decoder module.
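The `down` dictionary is described as being filled by forward hooks. The sketch below shows that mechanism with PyTorch's `register_forward_hook`. It assumes one hook per pyramid level, with each hook storing the layer's output under its level index. The toy `nn.Sequential` encoder and the `make_hook` helper are hypothetical stand-ins, not the EfficientNet3D layers.

```python
import torch
from torch import nn

# Toy 3D "encoder": each stride-2 conv halves D, H and W.
encoder = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 16, kernel_size=3, stride=2, padding=1),
)
pyramid = [encoder[0], encoder[2]]  # layers whose outputs feed the skip connections
down = {}  # level index -> activation, filled during the forward pass


def make_hook(level):
    """Hypothetical helper: build a hook that records the output of one level."""
    def hook(module, inputs, output):
        down[level] = output  # keep the activation for the decoder
    return hook


handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(pyramid)]
out = encoder(torch.randn(2, 1, 32, 32, 32))

# Detach the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

After the forward pass, `down[0]` holds the 16^3 activation and `down[1]` the 8^3 one. A decoder can then read these activations without the encoder returning them explicitly.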
- __init__(patch_size: int | tuple[int], num_pools: list[int] = [5, 5, 5], num_classes: int = 1, factor: int = 32, encoder_ckpt: str | None = None, model_ckpt: str | None = None, use_deep: bool = True, in_planes: int = 1)
- Parameters:
patch_size (tuple of int or int) – Shape of the input patch (D, H, W). Both forms are accepted: the encoder always works with an int, while the config always passes a tuple ¯\_(ツ)_/¯.
num_pools (list of int, default=[5, 5, 5]) – Number of pooling steps per spatial dimension.
num_classes (int, default=1) – Number of output segmentation classes.
factor (int, default=32) – Base scaling factor for the decoder channels.
encoder_ckpt (str or None, optional) – Path to a pretrained encoder checkpoint.
model_ckpt (str or None, optional) – Path to a full model checkpoint.
use_deep (bool, default=True) – Whether to use deep supervision in the decoder.
in_planes (int, default=1) – Number of input channels.
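Each pooling step reduces the spatial size of the feature maps. The sketch below computes the per-level sizes, assuming that every step halves a dimension with floor division. That assumption is ours, but it reproduces the sizes annotated in the `get_pyramid` example (100, 50, 25, 12, 6, 3 for a 100-voxel side with 5 pooling steps). `level_sizes` is a hypothetical helper.

```python
def level_sizes(patch_size, num_pools):
    """Hypothetical helper: spatial size of each pyramid level, per dimension.

    Assumes each pooling step halves the size with floor division, so level k
    has size side // 2**k.
    """
    sizes = []
    for side, pools in zip(patch_size, num_pools):
        sizes.append([side // 2**level for level in range(pools + 1)])
    return sizes


# A 100^3 patch with the default num_pools=[5, 5, 5].
print(level_sizes((100, 100, 100), [5, 5, 5])[0])  # [100, 50, 25, 12, 6, 3]
```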
- forward(x: torch.Tensor) → torch.Tensor
Forward pass of the model.
- Parameters:
x (torch.Tensor) – Input tensor of shape (N, C, D, H, W).
- Returns:
Output segmentation map of shape (N, num_classes, D, H, W).
- Return type:
torch.Tensor
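The shape contract of `forward` can be illustrated with a stand-in. In the sketch below, a 1x1x1 `nn.Conv3d` plays the role of EffUNet: it maps `in_planes` channels to `num_classes` channels while preserving D, H and W. It is not the real model. It only demonstrates the expected input and output shapes.

```python
import torch
from torch import nn

N, in_planes, num_classes = 2, 1, 3
D, H, W = 16, 32, 32

# Stand-in with the same input/output shape contract as EffUNet.forward.
model = nn.Conv3d(in_planes, num_classes, kernel_size=1)

x = torch.randn(N, in_planes, D, H, W)  # (N, C, D, H, W)
with torch.no_grad():
    y = model(x)                          # (N, num_classes, D, H, W)
```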
- freeze_encoder(freeze: bool = True) → None
Freeze or unfreeze the encoder’s weights.
- Parameters:
freeze (bool, default=True) – If True, disables gradient computation for the encoder parameters; if False, re-enables it.
- Return type:
None
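In PyTorch, freezing usually means toggling `requires_grad` on every encoder parameter. The sketch below assumes `freeze_encoder` works this way, although the source only states that gradient computation is disabled. The module-level function and the small `nn.Sequential` encoder are hypothetical stand-ins.

```python
from torch import nn


def freeze_encoder(encoder: nn.Module, freeze: bool = True) -> None:
    """Hypothetical stand-in: disable (or re-enable) gradients for encoder weights."""
    for param in encoder.parameters():
        param.requires_grad = not freeze


encoder = nn.Sequential(nn.Conv3d(1, 4, kernel_size=3), nn.BatchNorm3d(4))
freeze_encoder(encoder)                 # frozen: no gradients are computed for these weights
frozen = all(not p.requires_grad for p in encoder.parameters())
freeze_encoder(encoder, freeze=False)   # unfrozen again
```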
- biom3d.models.unet3d_eff.get_layer(model: torch.nn.Module, layer_names: list[str]) → torch.nn.Module
Retrieve a submodule from a model based on a list of keys.
- Parameters:
model (nn.Module) – The PyTorch model to search within.
layer_names (list of str) – List of submodule names to traverse, e.g. ['_blocks', '0', '_depthwise_conv'].
- Returns:
The requested submodule.
- Return type:
nn.Module
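Traversing a list of names amounts to repeated attribute lookups. The sketch below assumes `get_layer` resolves each name with `getattr`. For PyTorch modules, this also covers numeric keys such as '0', because `nn.Module.__getattr__` looks names up among the registered submodules. `get_layer_sketch` is a hypothetical name, and `SimpleNamespace` objects stand in for modules so the snippet runs without PyTorch.

```python
from functools import reduce
from types import SimpleNamespace


def get_layer_sketch(model, layer_names):
    """Hypothetical sketch: follow layer_names one attribute at a time."""
    return reduce(getattr, layer_names, model)


# Stand-in tree mimicking model._blocks[0]._depthwise_conv.
conv = SimpleNamespace(name="depthwise")
blocks = SimpleNamespace()
setattr(blocks, "0", SimpleNamespace(_depthwise_conv=conv))  # numeric name, as in nn.ModuleList
model = SimpleNamespace(_blocks=blocks)

layer = get_layer_sketch(model, ["_blocks", "0", "_depthwise_conv"])
print(layer.name)  # depthwise
```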
- biom3d.models.unet3d_eff.get_outfmaps(layer: torch.nn.Module) → int
Returns the number of output feature maps (channels) of a layer.
- Parameters:
layer (nn.Module) – The layer to inspect.
- Returns:
Number of output feature maps (channels).
- Return type:
int
Notes
Tries to read from ‘num_features’ or ‘in_channels’ attributes. Returns 0 on failure.
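The sketch below implements the lookup described in the Notes, assuming `num_features` is tried before `in_channels`. `get_outfmaps_sketch` is a hypothetical name. `SimpleNamespace` objects stand in for two layer types: a batch-norm layer, which exposes `num_features`, and a convolution, which exposes `in_channels`.

```python
from types import SimpleNamespace


def get_outfmaps_sketch(layer) -> int:
    """Hypothetical sketch: read num_features, then in_channels, else return 0."""
    for attr in ("num_features", "in_channels"):
        value = getattr(layer, attr, None)
        if value is not None:
            return value
    return 0


bn = SimpleNamespace(num_features=40)     # like nn.BatchNorm3d(40)
conv = SimpleNamespace(in_channels=16)    # like a conv layer
print(get_outfmaps_sketch(bn), get_outfmaps_sketch(conv), get_outfmaps_sketch(object()))  # 40 16 0
```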
- biom3d.models.unet3d_eff.get_pyramid(model: torch.nn.Module, pyramid: dict) → list[torch.nn.Module]
Retrieves multiple submodules from a model according to a dictionary of paths.
- Parameters:
model (nn.Module) – The model to extract layers from.
pyramid (dict) – Dictionary where each value is a list of strings indicating a submodule path.
Examples
>>> pyramid = {
...     0: ['_conv_stem'],             # 100
...     1: ['_blocks', '1', '_bn0'],   # 50
...     2: ['_blocks', '3', '_bn0'],   # 25
...     3: ['_blocks', '5', '_bn0'],   # 12
...     4: ['_blocks', '11', '_bn0'],  # 6
...     5: ['_bn1']                    # 3
... }
- Returns:
List of layers (submodules) extracted from the model.
- Return type:
list of nn.Module
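Once paths can be resolved, collecting the pyramid takes a single pass over the dictionary. The sketch below assumes the layers are returned in ascending level order and that each path is resolved by attribute lookup. The helper names and the `SimpleNamespace` stand-in model are hypothetical, and only three levels of the example above are reproduced.

```python
from functools import reduce
from types import SimpleNamespace


def get_layer_sketch(model, layer_names):
    """Hypothetical: resolve a submodule path by repeated getattr."""
    return reduce(getattr, layer_names, model)


def get_pyramid_sketch(model, pyramid):
    """Hypothetical: one resolved layer per pyramid level, in ascending level order."""
    return [get_layer_sketch(model, pyramid[level]) for level in sorted(pyramid)]


# Stand-in model using the attribute names from the example above.
stem, bn0, bn1 = SimpleNamespace(), SimpleNamespace(), SimpleNamespace()
blocks = SimpleNamespace()
setattr(blocks, "1", SimpleNamespace(_bn0=bn0))
model = SimpleNamespace(_conv_stem=stem, _blocks=blocks, _bn1=bn1)

pyramid = {0: ["_conv_stem"], 1: ["_blocks", "1", "_bn0"], 2: ["_bn1"]}
layers = get_pyramid_sketch(model, pyramid)
```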