Source code for biom3d.datasets.semseg_batchgen
"""Dataloader with batch_generator. Follow the nnUNet augmentation pipeline."""
import random
import numpy as np
import pandas as pd
from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
ContrastAugmentationTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from typing import Any, Hashable, Optional, Iterable
from biom3d.utils import DataHandlerFactory, DataHandler, get_folds_train_test_df
#---------------------------------------------------------------------------
# random crop and pad with batchgenerator
[docs]
def centered_crop(img:np.ndarray,
                  msk:np.ndarray,
                  center:Iterable[int],
                  crop_shape:Iterable[int],
                  margin:Iterable[float]=np.zeros(3),
                  )->tuple[np.ndarray,np.ndarray]:
    """
    Crop an image and its mask with a window centered on a given voxel.

    The window starts half a crop before ``center`` (shifted by ``margin``) and is
    clamped at 0 on the lower side only. The upper side is not clamped: if the
    window runs past the array, numpy slicing truncates it and the returned crop
    is smaller than ``crop_shape``.

    Parameters
    ----------
    img : numpy.ndarray
        Image data; axis 0 is kept whole, the remaining axes are cropped.
    msk : numpy.ndarray
        Mask data, cropped with the same window as ``img``.
    center : iterable of int
        Voxel the window is centered on (one value per cropped axis).
    crop_shape : iterable of int
        Shape of the crop.
    margin : iterable of float, default=np.zeros(3)
        Offset added to the window start on each axis.

    Returns
    -------
    crop_img : numpy.ndarray
        The cropped image.
    crop_msk : numpy.ndarray
        The cropped mask.
    """
    center_arr = np.array(center)
    shape_arr = np.array(crop_shape)
    margin_arr = np.array(margin)

    # First corner of the window, never below index 0.
    first = np.maximum(0, center_arr - shape_arr // 2 + margin_arr).astype(int)
    # Far corner; deliberately left unclamped (slicing truncates at the array edge).
    last = shape_arr + first

    # Whole first axis, then one slice per cropped axis.
    window = (slice(0, img.shape[0]),) + tuple(slice(lo, hi) for lo, hi in zip(first, last))
    return img[window], msk[window]
[docs]
def located_crop(img:np.ndarray,
                 msk:np.ndarray,
                 location:Iterable[int],
                 crop_shape:Iterable[int],
                 margin:Iterable[float]=np.zeros(3),
                 )->tuple[np.ndarray,np.ndarray]:
    """Do a random crop that is guaranteed to contain the location voxel.

    The window start is drawn uniformly among the positions for which
    ``location`` lies inside the window, at least ``margin`` voxels away from
    its borders, and (when possible) the window fits inside the image.

    Parameters
    ----------
    img : numpy.ndarray
        Image data; axis 0 is kept whole, the remaining axes are cropped.
    msk : numpy.ndarray
        Mask data, cropped with the same window as ``img``.
    location : iterable of int
        Specific voxel location to include in the crop.
    crop_shape : iterable of int
        Shape of the crop.
    margin : iterable of float, default=np.zeros(3)
        Minimal distance between the location and the crop borders.

    Returns
    -------
    crop_img : numpy.ndarray
        Cropped image data, containing the specified location voxel within the crop.
    crop_msk : numpy.ndarray
        Cropped mask data, corresponding to the cropped image region.
    """
    img_shape = np.array(img.shape)[1:]
    location = np.array(location)
    crop_shape = np.array(crop_shape)
    margin = np.array(margin)

    # Smallest admissible start: the last voxel of the window must reach
    # location + margin, i.e. start + crop_shape - 1 >= location + margin.
    lower_bound = np.maximum(0, np.ceil(location - crop_shape + 1 + margin)).astype(int)

    # Largest admissible start: not beyond location - margin, and the window
    # should still fit inside the image.
    last_start = np.floor(np.minimum(location - margin, img_shape - crop_shape)).astype(int)

    # np.random.randint excludes `high`, hence the +1; always keep at least one
    # candidate so that the draw never fails.
    higher_bound = np.maximum(lower_bound + 1, last_start + 1)

    start = np.random.randint(low=lower_bound, high=higher_bound)
    end = start + crop_shape

    # Whole first axis, then one slice per cropped axis.
    idx = (slice(0, img.shape[0]),) + tuple(slice(lo, hi) for lo, hi in zip(start, end))
    return img[idx], msk[idx]
[docs]
def foreground_crop(img:np.ndarray,
                    msk:np.ndarray,
                    final_size:Iterable[int],
                    fg_margin:Iterable[float],
                    fg:Optional[dict[int,np.ndarray]]=None,
                    use_softmax:bool=True,
                    )->tuple[np.ndarray,np.ndarray]:
    """Do a crop centered on a randomly chosen foreground voxel.

    A class is picked at random, then one voxel of that class, and the crop is
    centered on it with `centered_crop`. When no voxel is available for the
    picked class the function falls back to `random_crop`.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data. With a single channel it is read as an integer label map,
        otherwise as a one-hot encoding (one channel per class).
    final_size : iterable of int
        Final size of the cropped image and mask.
    fg_margin : iterable of float
        Margin around the foreground location.
    fg : dict of int to numpy.ndarray, optional
        Precomputed voxel locations per class. When given, the mask is not scanned.
    use_softmax : bool, default=True
        If True, channel 0 of a one-hot mask is never picked
        (presumably the background channel of a softmax output).

    Returns
    -------
    img : numpy.ndarray
        Cropped image data, focused on the foreground region.
    msk : numpy.ndarray
        Cropped mask data, corresponding to the cropped image region.
    """
    if fg is not None:
        classes = list(fg.keys())
        # An empty dictionary simply means "no foreground known".
        locations = fg[random.choice(classes)] if classes else []
    elif msk.shape[0] == 1:
        # Single channel: integer label map, labels are picked from 1 upward.
        # random.randint includes BOTH ends, so the upper bound is the largest
        # label itself (the former `max()+1` could select a non-existent label).
        rnd_label = random.randint(1, max(1, int(msk.max())))
        locations = np.argwhere(msk[0] == rnd_label)
    else:
        # One-hot encoded label: pick a channel.
        rnd_label = random.randint(int(use_softmax), msk.shape[0] - 1)
        locations = np.argwhere(msk[rnd_label] == 1)

    if np.array(locations).size == 0:
        # No voxel for the picked class: plain random crop.
        return random_crop(img, msk, final_size)

    # Index explicitly instead of random.choice: the latter's emptiness test is
    # not reliable on numpy arrays across Python versions.
    center = locations[random.randrange(len(locations))]
    return centered_crop(img, msk, center, final_size, fg_margin)
[docs]
def random_crop(img:np.ndarray,
                msk:np.ndarray,
                crop_shape:Iterable[int]
                )->tuple[np.ndarray,np.ndarray]:
    """
    Randomly crop a window of shape crop_shape from the image and the mask.

    Every window position that fits inside the image can be drawn, including the
    one touching the far border. On an axis where the image is not larger than
    the crop, the window starts at 0 and the result is truncated to the image size.

    Parameters
    ----------
    img : numpy.ndarray
        Image data; axis 0 is kept whole, the remaining axes are cropped.
    msk : numpy.ndarray
        Mask data, cropped with the same window as ``img``.
    crop_shape : array_like
        Shape of the crop.

    Raises
    ------
    AssertionError:
        If img (without its first axis) and crop_shape don't have the same number of dimensions.

    Returns
    -------
    crop_img : numpy.ndarray
        Cropped image data.
    crop_msk : numpy.ndarray
        Cropped mask data.
    """
    img_shape = np.array(img.shape)[1:]
    crop = np.asarray(crop_shape)
    assert len(img_shape)==len(crop),"[Error] Not the same dimensions! Image shape {}, Crop shape {}".format(img_shape, crop_shape)

    # Valid starts are 0 .. img_shape - crop (inclusive). np.random.randint
    # excludes `high`, hence the +1; the maximum keeps the range non-empty when
    # the image is smaller than the crop.
    start = np.random.randint(0, np.maximum(1, img_shape - crop + 1))
    end = start + crop

    # Whole first axis, then one slice per cropped axis.
    idx = (slice(0, img.shape[0]),) + tuple(slice(lo, hi) for lo, hi in zip(start, end))
    return img[idx], msk[idx]
[docs]
def centered_pad(img:np.ndarray,
final_size:np.ndarray,
msk:Optional[np.ndarray]=None,
)->np.ndarray|tuple[np.ndarray,np.ndarray]:
"""
Centered pad an img and msk to fit the final_size.
Parameters
----------
img : numpy.ndarray
Image data.
final_size : array_like
Final size after padding.
msk : numpy.ndarray, optional
Mask data.
Returns
-------
pad_img: numpy.ndarray
Padded image.
pad_mask: numpy.ndarray, optional
Padded image
"""
final_size = np.array(final_size)
img_shape = np.array(img.shape[1:])
start = (final_size-np.array(img_shape))//2
start = start * (start > 0)
end = final_size-(img_shape+start)
end = end * (end > 0)
pad = np.append([[0,0]], np.stack((start,end),axis=1), axis=0)
pad_img = np.pad(img, pad, 'constant', constant_values=0)
if msk is not None:
pad_msk = np.pad(msk, pad, 'constant', constant_values=0)
return pad_img, pad_msk
else:
return pad_img
[docs]
def random_crop_pad(img:np.ndarray,
                    msk:np.ndarray,
                    final_size:Iterable[int],
                    fg_rate:float=0.33,
                    fg_margin:Iterable[float]=np.zeros(3),
                    fg:Optional[dict[str,np.ndarray]]=None,
                    use_softmax:bool=True,
                    )->tuple[np.ndarray,np.ndarray]:
    """
    Random crop and pad if needed.

    With probability ``fg_rate`` the crop is centered on a foreground voxel
    (`foreground_crop`), otherwise it is drawn uniformly (`random_crop`). If the
    crop is not exactly ``final_size`` it is then passed to `centered_pad`.

    Parameters
    ----------
    img : numpy.ndarray or list of numpy.ndarray
        Image data. A list triggers batch mode: every element is processed
        independently and the results are stacked.
    msk : numpy.ndarray or list of numpy.ndarray
        Mask data, same layout as ``img``.
    final_size : iterable of int
        Final size after cropping and padding.
    fg_rate : float, default=0.33
        Probability of focusing the crop on the foreground.
    fg_margin : iterable of float, optional
        Margin around the foreground location.
    fg : dict of int to numpy.ndarray, optional
        Foreground information. In batch mode it is only used when given as a
        list/tuple with one entry per sample; a single dictionary cannot be
        attributed to a sample and is ignored there.
    use_softmax : bool, default=True
        If True, assumes softmax activation; otherwise sigmoid is used.

    Returns
    -------
    img : numpy.ndarray
        Cropped and padded image data.
    msk : numpy.ndarray
        Cropped and padded mask data.
    """
    if isinstance(img, list):  # then batch mode
        per_sample_fg = isinstance(fg, (list, tuple))
        imgs, msks = [], []
        for i, img_i in enumerate(img):
            # Forward every option: the recursion used to fall back silently
            # to the default fg_rate / fg_margin / use_softmax.
            img_i, msk_i = random_crop_pad(img_i,
                                           msk[i],
                                           final_size,
                                           fg_rate=fg_rate,
                                           fg_margin=fg_margin,
                                           fg=fg[i] if per_sample_fg else None,
                                           use_softmax=use_softmax)
            imgs.append(img_i)
            msks.append(msk_i)
        # can convert to array as they should have now all the same shape
        return np.array(imgs), np.array(msks)

    # choose if using foreground centered or random alignment
    if fg_rate > 0 and random.random() < fg_rate:
        img, msk = foreground_crop(img, msk, final_size, fg_margin, fg=fg, use_softmax=use_softmax)
    else:
        img, msk = random_crop(img, msk, final_size)

    # pad if needed: any axis whose size differs from the requested one
    if np.any(np.array(img.shape[1:]) != np.array(final_size)):
        img, msk = centered_pad(img=img, msk=msk, final_size=final_size)
    return img, msk
[docs]
class RandomCropAndPadTransform(AbstractTransform):
    """
    BatchGenerator transform wrapping `random_crop_pad`.

    :ivar str data_key: Key used to access data in dictionary.
    :ivar str label_key: Key used to access label in dictionary.
    :ivar float fg_rate: Foreground rate, probability of focusing crop on foreground.
    :ivar Iterable[int] crop_size: Size of the crop.
    """

    data_key: str
    label_key: str
    fg_rate: float
    crop_size: Iterable[int]

    def __init__(self,
                 crop_size:Iterable[int],
                 fg_rate:float=0.33,
                 data_key:str="data",
                 label_key:str="seg"):
        """
        Store the crop settings and the dictionary keys to operate on.

        Parameters
        ----------
        crop_size : iterable of int
            Size of the crop.
        fg_rate : float, default=0.33
            Probability of focusing the crop on the foreground.
        data_key : str, default="data"
            Key for the data in the data dictionary.
        label_key : str, default="seg"
            Key for the label in the data dictionary.
        """
        self.crop_size = crop_size
        self.fg_rate = fg_rate
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict:dict[str,Any])->dict[str,Any]:
        """
        Crop and pad the entries stored under the data and label keys.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing data arrays. The entries under
            `self.data_key` and `self.label_key` are passed to `random_crop_pad`
            as image and mask.

        Returns
        -------
        data_dict: dict
            The same dictionary, with both entries replaced by their cropped and padded version.
        """
        image = data_dict.get(self.data_key)
        label = data_dict.get(self.label_key)
        cropped = random_crop_pad(image, label, self.crop_size, self.fg_rate)
        data_dict[self.data_key], data_dict[self.label_key] = cropped
        return data_dict
#---------------------------------------------------------------------------
# image reader
[docs]
def imread(handler:DataHandler,
img:str,
msk:str,
loc:Optional[str]=None,
is3d:bool=True,
)->tuple[np.ndarray,np.ndarray,np.ndarray|None]:
"""
Read all data with the provided DataHandler.
Parameters
----------
handler: DataHandler
The DataHandler used to read data.
img: str
The path to the image.
msk: str
The path to the mask.
loc: str, optional
The path to the foreground. If None, no foreground will be returned.
is3d: bool, default=True
If image is in 3D
Returns
-------
img: numpy.ndarray
The image.
msk: numpy.ndarray
The mask.
fg: numpy.ndarray, optional
The foreground, or None.
"""
img,_ = handler.load(img)
msk,_ = handler.load(msk)
if loc is not None : fg,_ = handler.load(loc)
if len(img.shape) == 3 if is3d else 2:
img = np.expand_dims(img,0)
if len(msk.shape) == 3 if is3d else 2:
msk = np.expand_dims(msk,0)
assert (is3d and len(msk.shape)==4 and len(img.shape)==4) or (not is3d and len(msk.shape)==3 and len(img.shape)==3), "[Error] Your data has the wrong dimension."
return img, msk, fg if loc is not None else None
[docs]
class DataReader(AbstractTransform):
    """Read the data and add it to dictionary.

    :ivar str data_key: Key used to access data in dictionary.
    :ivar str label_key: Key used to access label in dictionary.
    :ivar str loc_key: Key used to access foreground in dictionary.
    :ivar bool is3d: If images are in 3d; forwarded to `imread`.
    :ivar DataHandler handler: DataHandler used to read data.
    """

    data_key: str
    label_key: str
    loc_key: str
    is3d: bool
    handler: DataHandler

    def __init__(self,
                 handler:DataHandler,
                 is3d:bool=True,
                 data_key:str="data",
                 label_key:str="seg",
                 loc_key:str='loc'):
        """Read the data and add it to dictionary.

        Parameters
        ----------
        handler : DataHandler
            DataHandler used to read data
        is3d : bool
            If images are in 3d; forwarded to `imread`.
        data_key : str
            Key used to access data in dictionary.
        label_key : str
            Key used to access label in dictionary.
        loc_key : str
            Key used to access foreground in dictionary.
        """
        self.is3d = is3d
        self.data_key = data_key
        self.label_key = label_key
        self.loc_key = loc_key
        self.handler = handler

    def __call__(self, **data_dict:dict[str,Any])->dict[str,Any]:
        """
        Replace the paths stored in data_dict by the data they point to.

        Parameters
        ----------
        **data_dict : dict
            Dictionary whose entries under `self.data_key` and `self.label_key`
            are paths (or lists of paths, in batch mode) to the image and the
            segmentation mask. The entry under `self.loc_key` (foreground) is
            optional.

        Returns
        -------
        data_dict: dict
            The modified data dictionary with raw data added to their keys.
        """
        data = data_dict.get(self.data_key)
        seg = data_dict.get(self.label_key)
        loc = data_dict.get(self.loc_key)

        if isinstance(data, list):
            # Batch mode: the lists are updated in place, one sample at a time.
            for i in range(len(data)):
                # The foreground entry is optional: do not index a missing list.
                loc_i = loc[i] if loc is not None else None
                data[i], seg[i], fg_i = imread(self.handler, data[i], seg[i], loc_i, is3d=self.is3d)
                if loc is not None:
                    loc[i] = fg_i
        else:
            data, seg, loc = imread(self.handler, data, seg, loc, is3d=self.is3d)

        data_dict[self.data_key] = data
        if seg is not None:
            data_dict[self.label_key] = seg
        if loc is not None:
            data_dict[self.loc_key] = loc
        return data_dict
#---------------------------------------------------------------------------
# training and validation augmentations
[docs]
def get_bbox(patch_size:Iterable[int],
final_patch_size:Iterable[int],
annotated_classes_key:Hashable,
data_shape: np.ndarray,
force_fg: bool,
class_locations: Optional[dict],
overwrite_class: Optional[int| tuple[int, ...]] = None,
verbose: bool = False
)->tuple[list[int],list[int]]:
"""
Compute bounding box coordinates for cropping a patch from the data, optionally focusing on foreground regions.
Parameters
----------
patch_size : iterable of int
Desired patch size to crop (dimensions).
final_patch_size : iterable of int
Current size of the patch after any previous cropping or resizing.
annotated_classes_key : hashable
Key identifying the annotated class in `class_locations`.
data_shape : numpy.ndarray
Shape of the full data volume or image from which the patch is cropped.
force_fg : bool
If True, ensures the patch contains at least one voxel of foreground classes.
class_locations : dict or None
Dictionary mapping class labels (int or tuple) to lists/arrays of voxel coordinates for that class.
Required if `force_fg` is True.
overwrite_class : int or tuple of int, optional
If set, forces the patch to focus on this class instead of randomly selected foreground class.
verbose : bool, default=False
If True, prints diagnostic messages.
Raises
------
AssertionError:
If class_locations is None and force_fg is True. Or overwrite_class not in class_locations
Returns
-------
bbox_lbs : list of int
Lower bounds (start indices) of the bounding box along each dimension.
bbox_ubs : list of int
Upper bounds (end indices) of the bounding box along each dimension.
Notes
-----
- The function calculates how much padding is needed if `final_patch_size` is smaller than `patch_size`.
- If `force_fg` is True, it attempts to center the bounding box on a randomly selected voxel of a foreground class.
- If no foreground voxel is found, it falls back to random cropping.
"""
# Force patch_size to have a get_item method
patch_size = np.array(patch_size)
# in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have
# locations for the given slice
need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
dim = len(data_shape)
for d in range(dim):
# if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
# always
if need_to_pad[d] + data_shape[d] < patch_size[d]:
need_to_pad[d] = patch_size[d] - data_shape[d]
# we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
# define what the upper and lower bound can be to then sample form them with np.random.randint
lbs = [- need_to_pad[i] // 2 for i in range(dim)]
ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - patch_size[i] for i in range(dim)]
# if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
# at least one of the foreground classes in the patch
if not force_fg:
bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
else:
assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
if overwrite_class is not None:
assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
'have class_locations (missing key)'
# this saves us a np.unique. Preprocessing already did that for all cases. Neat.
# class_locations keys can also be tuple
eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]
# if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
# strange formulation needed to circumvent
tmp = [i == annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
if any(tmp) and len(eligible_classes_or_regions) > 1:
eligible_classes_or_regions.pop(np.nonzero(tmp)[0][0])
if len(eligible_classes_or_regions) == 0:
# this only happens if some image does not contain foreground voxels at all
selected_class = None
if verbose:
print('Case does not contain any foreground classes')
else:
# I hate myself. Future me aint gonna be happy to read this
# 2022_11_25: had to read it today. Wasn't too bad
# 2025_08_07, Clement : speak for yourself
selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
(overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class
voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None
if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
# selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
# Make sure it is within the bounds of lb and ub
# i + 1 because we have first dimension 0!
bbox_lbs = [max(lbs[i], selected_voxel[i] - patch_size[i] // 2) for i in range(dim)]
else:
# If the image does not contain any foreground classes, we fall back to random cropping
bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
bbox_ubs = [bbox_lbs[i] + patch_size[i] for i in range(dim)]
return bbox_lbs, bbox_ubs
[docs]
class nnUNetRandomCropAndPadTransform(AbstractTransform):
    """
    Random cropping and padding transform for nnU-Net-style data augmentation.

    For every sample a box of size `aug_crop_size` is obtained from `get_bbox`
    (centered on a foreground voxel with probability `fg_rate`); the part of the
    box lying inside the data is extracted and the rest is filled by padding.

    :ivar Iterable[int] aug_crop_size : Final shape after cropping and padding (target shape).
    :ivar Iterable[int] crop_size : Crop size for network input (may differ from aug_crop_size).
    :ivar float fg_rate : Probability of forcing the crop to focus on the foreground class.
    :ivar str data_key : Key for the input data in the data dictionary.
    :ivar str label_key : Key for the segmentation labels in the data dictionary.
    :ivar str class_loc_key : Key for the precomputed voxel locations per class in the data dictionary.
    """

    data_key:str
    label_key:str
    class_loc_key:str
    fg_rate:float
    crop_size:Iterable[int]
    aug_crop_size:Iterable[int]

    def __init__(self,
                 aug_crop_size:Iterable[int],
                 crop_size:Iterable[int],
                 fg_rate:float=0.33,
                 data_key:str="data",
                 label_key:str="seg",
                 class_loc_key:str="loc",
                 ):
        """
        Store the crop sizes, the foreground rate and the dictionary keys.

        Parameters
        ----------
        aug_crop_size : iterable of int
            Final shape after cropping and padding (target shape).
        crop_size : iterable of int
            Crop size for network input (may differ from aug_crop_size).
        fg_rate : float, default=0.33
            Probability of forcing the crop to focus on the foreground class.
        data_key : str, default="data"
            Key for the input data in the data dictionary.
        label_key : str, default="seg"
            Key for the segmentation labels in the data dictionary.
        class_loc_key : str, default="loc"
            Key for the precomputed voxel locations per class in the data dictionary.
        """
        self.aug_crop_size = aug_crop_size
        self.crop_size = crop_size
        self.fg_rate = fg_rate
        self.data_key = data_key
        self.label_key = label_key
        self.class_loc_key = class_loc_key

    def __call__(self, **data_dict:dict[str,Any])->dict:
        """
        Apply the crop and pad transform to a batch of data.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing data arrays. Must contain keys matching
            `self.data_key`, `self.label_key` and `self.class_loc_key` which
            correspond to the input data, segmentation mask and per-class voxel
            locations respectively, one entry per sample.

        Returns
        -------
        **data_dict : dict
            Updated data dictionary; data becomes a float32 array and the mask an
            int16 array, both of shape (samples, channels, *aug_crop_size).
        """
        images = data_dict.get(self.data_key)
        labels = data_dict.get(self.label_key)
        class_locs = data_dict.get(self.class_loc_key)

        # Layout is taken from the first sample: channel axis first, then spatial axes.
        n_spatial = len(images[0].shape[1:])
        n_img_channels = images[0].shape[0]
        n_seg_channels = labels[0].shape[0]
        target = list(self.aug_crop_size)

        data_all = np.zeros([len(images), n_img_channels] + target, dtype=np.float32)
        seg_all = np.zeros([len(labels), n_seg_channels] + target, dtype=np.int16)

        for idx, (image, label, locations) in enumerate(zip(images, labels, class_locs)):
            shape = np.array(image.shape[1:])
            bbox_lbs, bbox_ubs = get_bbox(
                patch_size=self.aug_crop_size,
                final_patch_size=self.crop_size,
                annotated_classes_key=list(locations.keys()),
                data_shape=shape,
                force_fg=random.random()<self.fg_rate,
                class_locations=locations,
                overwrite_class=None,
                verbose=False,
            )

            # Part of the box that lies inside the data: cropping first gives a
            # smaller array, which is then padded up to the target size.
            inside_lo = [max(0, bbox_lbs[axis]) for axis in range(n_spatial)]
            inside_hi = [min(shape[axis], bbox_ubs[axis]) for axis in range(n_spatial)]
            spatial = [slice(lo, hi) for lo, hi in zip(inside_lo, inside_hi)]

            image = image[tuple([slice(0, n_img_channels)] + spatial)]
            label = label[tuple([slice(0, n_seg_channels)] + spatial)]

            # Whatever the box covers outside the data is restored by padding:
            # zeros for the image, -1 for the segmentation.
            padding = [(-min(0, bbox_lbs[axis]), max(bbox_ubs[axis] - shape[axis], 0)) for axis in range(n_spatial)]
            data_all[idx] = np.pad(image, ((0, 0), *padding), 'constant', constant_values=0)
            seg_all[idx] = np.pad(label, ((0, 0), *padding), 'constant', constant_values=-1)

        data_dict[self.data_key] = data_all
        data_dict[self.label_key] = seg_all
        return data_dict
[docs]
class Convert2DTo3DTransform(AbstractTransform):
    """
    Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z).

    :ivar list[str] | tuple[str] apply_to_keys: Key of the data dictionary to convert, default=('data','seg')
    """

    apply_to_keys:list[str]| tuple[str]

    def __init__(self,
                 apply_to_keys: list[str]| tuple[str] = ('data', 'seg'),
                 ):
        """
        Store the keys of the entries to convert back to 5D.

        Parameters
        ----------
        apply_to_keys: list or tuple of str, default=('data','seg')
            Key of the data dictionary to convert
        """
        self.apply_to_keys = apply_to_keys

    def __call__(self, **data_dict:dict[str,Any])->dict:
        """
        Apply the conversion to a batch of data.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing data arrays. Must contain the keys listed in
            `self.apply_to_keys` and, for each of them, the `orig_shape_<key>`
            entry holding the shape before the 3D to 2D conversion.

        Raises
        ------
        AssertionError:
            If the `orig_shape_<key>` entry of a key of apply_to_keys is not in data_dict.

        Returns
        -------
        **data_dict : dict
            Updated data dictionary with 3D transform.
        """
        for key in self.apply_to_keys:
            shape_key = f'orig_shape_{key}'
            assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \
                                                  f'Convert2DTo3DTransform only works in tandem with ' \
                                                  f'Convert3DTo2DTransform and you probably forgot to add ' \
                                                  f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \
                                                  f'is where the missing key is generated)'
            stored = data_dict[shape_key]
            flat = data_dict[key]
            # First three axes come from the stored shape, the last two from the
            # current array.
            height, width = flat.shape[-2], flat.shape[-1]
            data_dict[key] = flat.reshape((stored[0], stored[1], stored[2], height, width))
        return data_dict
[docs]
class Convert3DTo2DTransform(AbstractTransform):
    """
    Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel.

    :ivar list[str] | tuple[str] apply_to_keys: Key of the data dictionary to convert, default=('data','seg')
    """

    def __init__(self, apply_to_keys: list[str]| tuple[str] = ('data', 'seg')):
        """
        Store the keys of the entries to flatten to 4D.

        Parameters
        ----------
        apply_to_keys: list or tuple of str, default=('data','seg')
            Key of the data dictionary to convert
        """
        self.apply_to_keys = apply_to_keys

    def __call__(self, **data_dict:dict[str,Any])->dict:
        """
        Apply the conversion to a batch of data.

        The original shape of each converted entry is stored under
        `orig_shape_<key>`.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing data arrays. Must contain the keys listed in
            `self.apply_to_keys`.

        Raises
        ------
        AssertionError:
            If an entry is not 5D, or if its `orig_shape_<key>` entry already exists in data_dict.

        Returns
        -------
        **data_dict : dict
            Updated data dictionary with 2D transform.
        """
        for key in self.apply_to_keys:
            full_shape = data_dict[key].shape
            assert len(full_shape) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.'
            batch, channels, depth, height, width = full_shape
            # Merge the channel axis and the first spatial axis.
            data_dict[key] = data_dict[key].reshape((batch, channels * depth, height, width))
            shape_key = f'orig_shape_{key}'
            assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \
                                                      f'It does that using the {shape_key} key. That key is ' \
                                                      f'already taken. Bummer.'
            data_dict[shape_key] = full_shape
        return data_dict
[docs]
class DictToTuple(AbstractTransform):
    """
    Return a data and seg instead of a dictionary.

    :ivar str data_key: Key for the input data in the dictionary, default="data"
    :ivar str label_key: Key for the label/segmentation in the dictionary, default="seg"
    """

    data_key:str
    label_key:str

    def __init__(self, data_key:str="data", label_key:str="seg"):
        """
        Store the keys of the two entries to extract.

        Parameters
        ----------
        data_key : str, default="data"
            Key for the input data in the dictionary.
        label_key : str, default="seg"
            Key for the label/segmentation in the dictionary.
        """
        self.data_key = data_key
        self.label_key = label_key

    def __call__(self, **data_dict:dict[str,Any])->tuple[Any,Any]:
        """
        Extract data_key and label_key from a dictionary and returns them as a tuple.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing data arrays. A missing key yields None for the
            corresponding element.

        Returns
        -------
        data
            Dictionary entry for data_key
        seg
            Dictionary entry for label_key; when that entry is a list, only its
            first element is returned.
        """
        label = data_dict.get(self.label_key)
        # A list of segmentations is reduced to its first element.
        first_label = label[0] if isinstance(label, list) else label
        return data_dict.get(self.data_key), first_label
[docs]
class DownsampleSegForDSTransform2(AbstractTransform):
    """
    Transform that generates downsampled versions of a segmentation map for deep supervision.

    This transform stores the results in `data_dict[output_key]` as a list of segmentations, each scaled
    according to a corresponding entry in `ds_scales`.

    :ivar tuple | List ds_scales: Scaling factors per deep supervision level. Each entry can be a float (same scaling for all axes) or a tuple of floats (individual scaling per axis).
    :ivar int order: Interpolation order to use for resizing (0 = nearest neighbor).
    :ivar str input_key: Key to access the input segmentation in `data_dict`.
    :ivar str output_key: Key under which to store the output list of downsampled segmentations.
    :ivar tuple[int] axes: Axes along which to apply the downsampling. If None, every axis from index 2 on is used, i.e., batch and channel are skipped.
    """

    axes:tuple[int]
    output_key:str
    input_key:str
    order:int
    ds_scales:list| tuple

    def __init__(self,
                 ds_scales: list | tuple,
                 order: int = 0,
                 input_key: str = "seg",
                 output_key: str = "seg",
                 axes: Optional[tuple[int]] = None):
        """
        Store the scales and the options used to build the downsampled segmentations.

        Each entry in ds_scales specifies one deep supervision output and its
        resolution relative to the original data, for example 0.25 specifies 1/4
        of the original shape. ds_scales can also be a tuple of tuples, for
        example ((1, 1, 1), (0.5, 0.5, 0.5)), to specify the downsampling for
        each axis independently.

        Parameters
        ----------
        ds_scales : list or tuple
            Scaling factors per deep supervision level. Each entry can be a float (same scaling for all axes) or a
            tuple of floats (individual scaling per axis).
        order : int, default=0
            Interpolation order to use for resizing (0 = nearest neighbor).
        input_key : str, default="seg"
            Key to access the input segmentation in `data_dict`.
        output_key : str, default="seg"
            Key under which to store the output list of downsampled segmentations.
        axes : tuple of int, optional
            Axes along which to apply the downsampling. If None, every axis from index 2 on is used.
        """
        self.ds_scales = ds_scales
        self.order = order
        self.input_key = input_key
        self.output_key = output_key
        self.axes = axes

    def __call__(self, **data_dict:dict[str,Any])->dict:
        """
        Apply the downsampling to a batch of data.

        Parameters
        ----------
        **data_dict : dict
            Dictionary containing input data. Must contain `input_key`.

        Raises
        ------
        AssertionError:
            If a element of ds_scales has not the same length as axes

        Returns
        -------
        **data_dict : dict
            Modified `data_dict` with `output_key` storing a list of downsampled segmentations.
        """
        if self.axes is not None:
            axes = self.axes
        else:
            axes = list(range(2, len(data_dict[self.input_key].shape)))

        pyramid = []
        for scale in self.ds_scales:
            if isinstance(scale, (tuple, list)):
                assert len(scale) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
                                                f'for each axis) then the number of entried in that tuple (here ' \
                                                f'{len(scale)}) must be the same as the number of axes (here {len(axes)}).'
            else:
                # A single factor applies to every axis.
                scale = [scale] * len(axes)

            seg = data_dict[self.input_key]
            if all(factor == 1 for factor in scale):
                # Full resolution: the input itself is reused, not copied.
                pyramid.append(seg)
                continue

            # Scale the selected axes, then round to integer sizes.
            new_shape = np.array(seg.shape).astype(float)
            for factor, axis in zip(scale, axes):
                new_shape[axis] *= factor
            new_shape = np.round(new_shape).astype(int)

            # Resize every (sample, channel) volume separately; the target size
            # passed on is new_shape without its first two entries.
            resized = np.zeros(new_shape, dtype=seg.dtype)
            for b in range(seg.shape[0]):
                for c in range(seg.shape[1]):
                    resized[b, c] = resize_segmentation(seg[b, c], new_shape[2:], self.order)
            pyramid.append(resized)

        data_dict[self.output_key] = pyramid
        return data_dict
[docs]
def get_training_transforms(aug_patch_size: np.ndarray| tuple[int],
                            patch_size: np.ndarray | tuple[int],
                            fg_rate: float,
                            rotation_for_DA: dict,
                            deep_supervision_scales: list | tuple | None,
                            mirror_axes: tuple[int, ...],
                            handler:DataHandler,
                            do_dummy_2d_data_aug: bool,
                            order_resampling_data: int = 3,
                            order_resampling_seg: int = 1,
                            border_val_seg: int = -1,
                            use_data_reader: bool = True,
                            ) -> AbstractTransform:
    """
    Assemble the training augmentation pipeline (nnU-Net style) as a single composed transform.

    Parameters
    ----------
    aug_patch_size : numpy.ndarray or tuple of int
        First size argument handed to `nnUNetRandomCropAndPadTransform`.
    patch_size : numpy.ndarray or tuple of int
        Final patch size; also the second size argument of `nnUNetRandomCropAndPadTransform`.
    fg_rate : float
        Foreground rate handed to `nnUNetRandomCropAndPadTransform`.
    rotation_for_DA : dict
        Rotation ranges; keys 'x', 'y' and 'z' are read and forwarded to `SpatialTransform`.
    deep_supervision_scales : list, tuple or None
        When not None, a `DownsampleSegForDSTransform2` built from these scales is appended.
    mirror_axes : tuple[int, ...]
        Axes for `MirrorTransform`; mirroring is skipped when None or empty.
    handler : DataHandler
        Handler given to `DataReader`. Only used when `use_data_reader` is True.
    do_dummy_2d_data_aug : bool
        If True, the spatial transform is wrapped between `Convert3DTo2DTransform` and
        `Convert2DTo3DTransform` and works on `patch_size[1:]`.
    order_resampling_data : int, default=3
        Forwarded to `SpatialTransform` as `order_data`.
    order_resampling_seg : int, default=1
        Forwarded to `SpatialTransform` as `order_seg`.
    border_val_seg : int, default=-1
        Forwarded to `SpatialTransform` as `border_cval_seg`.
    use_data_reader : bool, default=True
        If True, the pipeline starts with a `DataReader`.

    Returns
    -------
    AbstractTransform
        The `Compose` of all transforms, in the order they were added.
    """
    pipeline = [DataReader(handler)] if use_data_reader else []

    pipeline.append(nnUNetRandomCropAndPadTransform(aug_patch_size,
                                                    patch_size,
                                                    fg_rate,
                                                    data_key="data",
                                                    label_key="seg",
                                                    class_loc_key="loc",))

    # Dummy 2D augmentation: the first patch axis is left out of the spatial transform
    # and of the low-resolution simulation.
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        pipeline.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        ignore_axes = None
        patch_size_spatial = patch_size

    spatial = SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
        do_rotation=True,
        angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'],
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True, scale=(0.7, 1.4),
        border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2,
        independent_scale_for_each_axis=False,  # todo experiment with this
    )
    pipeline.append(spatial)

    if do_dummy_2d_data_aug:
        pipeline.append(Convert2DTo3DTransform())

    # Crop to the final size, then intensity / resolution augmentations.
    pipeline.extend([
        CenterCropTransform(patch_size, data_key='data', label_key='seg'),
        GaussianNoiseTransform(p_per_sample=0.1),
        GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                              p_per_channel=0.5),
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15),
        ContrastAugmentationTransform(p_per_sample=0.15),
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                       ignore_axes=ignore_axes),
        GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1),
        GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3),
    ])

    if mirror_axes is not None and len(mirror_axes) > 0:
        pipeline.append(MirrorTransform(mirror_axes))

    # Label clean-up and renaming of 'seg' into 'target'.
    pipeline.append(RemoveLabelTransform(-1, 0))
    pipeline.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        pipeline.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
                                                     output_key='target'))

    # Conversion to tensors and to a (data, target) output.
    pipeline.append(NumpyToTensor(['data', 'target'], 'float'))
    pipeline.append(DictToTuple('data', 'target'))
    return Compose(pipeline)
[docs]
def get_validation_transforms(patch_size: np.ndarray | tuple[int],
                              fg_rate: float,
                              handler:DataHandler,
                              deep_supervision_scales: list | tuple | None = None,
                              use_data_reader: bool = True,
                              ) -> AbstractTransform:
    """
    Assemble the validation pipeline (nnU-Net style): loading, cropping and formatting, no augmentation.

    Parameters
    ----------
    patch_size : numpy.ndarray or tuple of int
        Patch size, passed as both size arguments of `nnUNetRandomCropAndPadTransform`.
    fg_rate : float
        Foreground rate handed to `nnUNetRandomCropAndPadTransform`.
    handler : DataHandler
        Handler given to `DataReader`. Only used when `use_data_reader` is True.
    deep_supervision_scales : list, tuple or None, optional
        When not None, a `DownsampleSegForDSTransform2` built from these scales is appended.
    use_data_reader : bool, default=True
        If True, the pipeline starts with a `DataReader`.

    Returns
    -------
    AbstractTransform
        The `Compose` of all transforms, in the order they were added.
    """
    steps = [DataReader(handler)] if use_data_reader else []

    steps.append(nnUNetRandomCropAndPadTransform(patch_size,
                                                 patch_size,
                                                 fg_rate,
                                                 data_key="data",
                                                 label_key="seg",
                                                 class_loc_key="loc",))

    # Label clean-up and renaming of 'seg' into 'target'.
    steps.extend([
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
    ])

    if deep_supervision_scales is not None:
        steps.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
                                                  output_key='target'))

    # Conversion to tensors and to a (data, target) output.
    steps.extend([
        NumpyToTensor(['data', 'target'], 'float'),
        DictToTuple('data', 'target'),
    ])
    return Compose(steps)
#---------------------------------------------------------------------------
# dataloader
[docs]
class BatchGenDataLoader(SlimDataLoaderBase):
    """
    Similar to torchio.SubjectsDataset but can be used with an unlimited amount of steps.

    :ivar str img_path: Path to collection containing the images.
    :ivar str msk_path: Path to collection containing the masks.
    :ivar str | None fg_path: Path to collection containing foreground information.
    :ivar int batch_size: Size of the batches.
    :ivar int nbof_steps: Number of steps per epoch.
    :ivar bool load_data: If True the dataset entries hold loaded arrays, otherwise only their paths.
    :ivar int fold: Fold index given at construction (always set, even without a folds CSV).
    :ivar list fnames: Names of the images used by this loader (training or validation subset).
    :ivar numpy.ndarray indices: A array of unsigned int representing the possibles index for images.
    :ivar int current_position: Step counter for the current epoch.
    :ivar bool was_initialized: If the batch generator was initialized, used for safeguarding.
    """

    img_path:str
    msk_path:str
    fg_path:Optional[str]
    batch_size:int
    nbof_steps:int
    load_data:bool
    indices:np.ndarray
    current_position:int
    was_initialized:bool

    def __init__(
        self,
        img_path:str,
        msk_path:str,
        batch_size:int,
        nbof_steps:int,
        fg_path:Optional[str] = None,
        folds_csv:Optional[str] = None,
        fold :int = 0,
        val_split:float = 0.25,
        train :bool = True,
        load_data:bool = False,
        # batch_generator parameters
        num_threads_in_mt=12,
    ):
        """
        Similar to torchio.SubjectsDataset but can be used with an unlimited amount of steps.

        Parameters
        ----------
        img_path : str
            Path to collection containing the images.
        msk_path : str
            Path to collection containing the masks.
        batch_size : int
            Size of the batches.
        nbof_steps : int
            Number of steps per epoch.
        fg_path : str, optional
            Path to collection containing foreground information.
        folds_csv : str, optional
            CSV file containing fold information for dataset splitting.
        fold : int, optional
            Current fold number for training/validation splitting. Only used for the split when
            `folds_csv` is given.
        val_split : float, optional
            Proportion of data to be used for validation when `folds_csv` is not given.
            At least one image is always kept for validation.
        train : bool, optional
            If True, use the dataset for training; otherwise, use it for validation.
        load_data : bool, optional
            if True, loads the all dataset into computer memory (faster but more memory expensive). ONLY COMPATIBLE WITH .npy PREPROCESSED IMAGES
        num_threads_in_mt : int, optional
            Number of threads in multi-threaded augmentation.
        """
        self.img_path = img_path
        self.msk_path = msk_path
        self.fg_path = fg_path
        self.batch_size = batch_size
        self.nbof_steps = nbof_steps
        self.load_data = load_data
        # Set unconditionally so the attribute exists whether or not a folds CSV is used.
        self.fold = fold

        handler = DataHandlerFactory.get(
            self.img_path,
            read_only=True,
            msk_path = msk_path,
            fg_path = fg_path,
        )

        # get the training and validation names
        if folds_csv is not None:
            df = pd.read_csv(folds_csv)
            trainset, testset = get_folds_train_test_df(df, verbose=False)
            # The selected fold is the validation set, all remaining folds are merged for training.
            self.val_imgs = trainset[self.fold]
            del trainset[self.fold]
            self.train_imgs = []
            for fold_imgs in trainset:
                self.train_imgs += fold_imgs
        else:
            # No CSV: the first `val_split` fraction of the images (0.25 by default, never less
            # than one image) is used for validation, the rest for training.
            all_set = handler.extract_inner_path(handler.images)
            nbof_val = np.round(val_split * len(all_set)).astype(int)
            if nbof_val == 0:
                nbof_val = 1
            self.train_imgs = all_set[nbof_val:]
            self.val_imgs = all_set[:nbof_val]
            testset = []

        self.train = train
        if self.train:
            print("current fold: {}\n "
                  "length of the training set: {}\n "
                  "length of the validation set: {}\n "
                  "length of the testing set: {}".format(fold, len(self.train_imgs), len(self.val_imgs), len(testset)))

        self.fnames = self.train_imgs if self.train else self.val_imgs

        handler.open(
            img_path = img_path,
            msk_path = msk_path,
            fg_path = fg_path,
            img_inner_paths_list = self.fnames,
            msk_inner_paths_list = self.fnames,
            # Foreground files share the image name up to its first '.', with a '.pkl' extension.
            # `split` keeps the whole name when it has no '.', whereas slicing with `find`
            # (which returns -1 in that case) would drop the last character.
            fg_inner_paths_list = [f.split('.', 1)[0] + '.pkl' for f in self.fnames],
        )

        # print train and validation image names
        print("{} images: {}".format("Training" if self.train else "Validation", self.fnames))

        # Build the dataset: loaded arrays when `load_data` is True, only their paths otherwise.
        data = []
        if load_data:
            for i, m, f in handler:
                img = handler.load(i)[0]
                msk = handler.load(m)[0]
                fg = handler.load(f)[0] if self.fg_path is not None else None
                data.append({'data': img, 'seg': msk, 'loc': fg})
        else:
            for i, m, f in handler:
                data.append({'data': i, 'seg': m, 'loc': f})

        super().__init__(
            data,
            batch_size,
            num_threads_in_mt,
        )

        # NOTE(review): `self._data` is presumably set by SlimDataLoaderBase.__init__ — confirm.
        self.indices = np.arange(len(self._data))

        self.current_position = 0
        self.was_initialized = False

    def reset(self)->None:
        """
        Reset the internal state of the batch generator.

        Resets the current position in the epoch and marks the generator as initialized.

        Raises
        ------
        AssertionError
            If `self.indices` is not set.

        Returns
        -------
        None
        """
        assert self.indices is not None

        self.current_position = 0
        self.was_initialized = True

    def get_indices(self)->np.ndarray:
        """
        Retrieve a random batch of indices from the dataset.

        Indices are drawn with replacement. Each successful call advances the step counter by
        `number_of_threads_in_multithreaded`.

        Returns
        -------
        numpy.ndarray
            A NumPy array of randomly sampled indices with shape (batch_size,).

        Raises
        ------
        StopIteration
            If the number of allowed steps per epoch is exceeded.
        """
        if not self.was_initialized:
            self.reset()

        indices = np.random.choice(self.indices, self.batch_size, replace=True)

        if self.current_position < self.nbof_steps:
            self.current_position += self.number_of_threads_in_multithreaded
            return indices
        else:
            # End of epoch: the next call will go through `reset` again.
            self.was_initialized=False
            raise StopIteration

    def generate_train_batch(self)->dict:
        """
        Generate a training batch from the dataset.

        Returns
        -------
        dict
            A dictionary with the following keys:

            - 'data': List of the 'data' entries of the sampled items.
            - 'seg': List of the corresponding 'seg' entries.
            - 'loc': List of the corresponding 'loc' entries.
        """
        indices = self.get_indices()
        batch_list = [self._data[i] for i in indices]
        batch = {
            'data':[data['data'] for data in batch_list],
            'seg': [data['seg'] for data in batch_list],
            'loc': [data['loc'] for data in batch_list],
        }

        return batch
#---------------------------------------------------------------------------
# multi-threading
[docs]
def get_patch_size(final_patch_size:list[int]| tuple[int]| np.ndarray,
rot_x:float|tuple[float]|list[float],
rot_y:float|tuple[float]|list[float],
rot_z:float|tuple[float]|list[float],
scale_range:tuple[float]|list[float],
)->np.ndarray:
"""
Compute the required patch size to accommodate rotation and scaling augmentations.
This function determines the maximum patch size needed after applying possible
rotations and scaling to ensure that the original patch fits entirely within the
transformed space (i.e., no cropping due to rotation).
Parameters
----------
final_patch_size : list/tuple/ndarray of int
The desired final patch size before any augmentations.
Should be 2D (for 2D images) or 3D (for volumetric data).
rot_x : float or tuple/list of float
Rotation angle(s) in radians around the x-axis.
If a tuple or list, the maximum absolute value is used.
rot_y : float or tuple/list of float
Rotation angle(s) in radians around the y-axis.
Ignored if input is 2D.
rot_z : float or tuple/list of float
Rotation angle(s) in radians around the z-axis.
Ignored if input is 2D.
scale_range : tuple or list of float
Range of possible scaling factors applied during augmentation.
The minimum value is used to compute the worst-case required patch size.
Returns
-------
final_shape: numpy.ndarray of int
The adjusted patch size that ensures the transformed patch still contains
the original field of view, accounting for rotation and scaling.
Notes
-----
- The maximum allowed rotation is clipped to 90° (π/2 radians) for numerical stability.
- The patch size is increased to accommodate potential rotation "corners"
that extend beyond the original bounds.
"""
if isinstance(rot_x, (tuple, list)):
rot_x = max(np.abs(rot_x))
if isinstance(rot_y, (tuple, list)):
rot_y = max(np.abs(rot_y))
if isinstance(rot_z, (tuple, list)):
rot_z = max(np.abs(rot_z))
rot_x = min(90 / 360 * 2. * np.pi, rot_x)
rot_y = min(90 / 360 * 2. * np.pi, rot_y)
rot_z = min(90 / 360 * 2. * np.pi, rot_z)
coords = np.array(final_patch_size)
final_shape = np.copy(coords)
if len(coords) == 3:
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
elif len(coords) == 2:
final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
final_shape /= min(scale_range)
return final_shape.astype(int)
[docs]
def configure_rotation_dummy_da_mirroring_and_inital_patch_size(patch_size:Iterable[int],
                                                                )->tuple[dict[str, tuple[float, float]], bool, np.ndarray, tuple[int, ...]]:
    """
    Derive rotation ranges, the dummy 2D augmentation flag, the initial patch size and the mirror axes.

    The choices are simple heuristics based on the patch shape only.

    Parameters
    ----------
    patch_size: iterable of int
        Patch size as a tuple, array, list,... Must have 2 or 3 entries.

    Raises
    ------
    RuntimeError:
        If patch_size is not 2 or 3 dimensional.

    Returns
    -------
    rotation_for_DA: dict of str to tuple of float
        Rotation range in radians for the keys 'x', 'y' and 'z'.
    do_dummy_2d_data_aug: bool
        Whether dummy 2D data augmentation should be used. Always False for 2D patches; for 3D
        patches it is True when the largest side is more than 3 times the first side.
    initial_patch_size : numpy.ndarray
        Patch size returned by `get_patch_size` for these rotations. With dummy 2D
        augmentation, its first entry is reset to `patch_size[0]`.
    mirror_axes : tuple of int
        (0, 1) for 2D patches, (0, 1, 2) for 3D patches.
    """
    patch_size = np.array(patch_size)
    dim = len(patch_size)
    if dim not in (2, 3):
        raise RuntimeError(f"Patch must be of 3 or 2 dimension, found : '{dim}'")

    def symmetric_range(degrees):
        """Return (-degrees, +degrees) converted to radians."""
        return (-degrees / 360 * 2. * np.pi, degrees / 360 * 2. * np.pi)

    no_rotation = (0, 0)

    # TODO rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
    if dim == 2:
        do_dummy_2d_data_aug = False
        # TODO revisit this parametrization
        # Elongated patches only get a small rotation, others a full one.
        is_elongated = max(patch_size) / min(patch_size) > 1.5
        in_plane = symmetric_range(15.) if is_elongated else symmetric_range(180.)
        rotation_for_DA = {'x': in_plane, 'y': no_rotation, 'z': no_rotation}
        mirror_axes = (0, 1)
    else:
        # TODO this is not ideal. A patch such as (64, 16, 128) would also get a full 180deg 2d rotation.
        # order of the axes is determined by spacing, not image size
        do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > 3
        if do_dummy_2d_data_aug:
            # Full rotation on 'x' only; the other two are disabled. This should probably be restricted too.
            rotation_for_DA = {'x': symmetric_range(180.), 'y': no_rotation, 'z': no_rotation}
        else:
            rotation_for_DA = {axis: symmetric_range(30.) for axis in ('x', 'y', 'z')}
        mirror_axes = (0, 1, 2)

    # TODO the scale range used here is not the one of the spatial transform (kept as in the old nnunet for now)
    initial_patch_size = get_patch_size(patch_size[-dim:],
                                        rotation_for_DA['x'],
                                        rotation_for_DA['y'],
                                        rotation_for_DA['z'],
                                        (0.85, 1.25))
    if do_dummy_2d_data_aug:
        # The first axis is excluded from the spatial augmentation, so it keeps its final size.
        initial_patch_size[0] = patch_size[0]

    print(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')

    return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
[docs]
class MTBatchGenDataLoader(MultiThreadedAugmenter):
    """
    Multi-threaded data loader combining a `BatchGenDataLoader` with a transform pipeline.

    :ivar int length: Number of batches (equal to `nbof_steps`).
    """

    def __init__(
        self,
        img_path:str,
        msk_path:str,
        patch_size:Iterable[int],
        batch_size:int,
        nbof_steps:int,
        fg_path:Optional[str] = None,
        folds_csv:Optional[str] = None,
        fold :int = 0,
        val_split:float = 0.25,
        train:bool = True,
        load_data:bool = False,
        fg_rate:float = 0.33,
        num_threads_in_mt:int=12,
        **kwargs,
    ):
        """
        Build the batch generator and its training or validation transforms.

        Parameters
        ----------
        img_path : str
            Path to a collection containing the images.
        msk_path : str
            Path to a collection containing the masks.
        patch_size : iterable of int
            The size of the patches to be extracted.
        batch_size : int
            Size of the batches.
        nbof_steps : int
            Number of steps per epoch.
        fg_path : str, optional
            Path to a collection containing foreground information. Defaults to None in the
            signature but is currently mandatory: None raises a ValueError.
        folds_csv : str, optional
            CSV file containing fold information for dataset splitting.
        fold : int, default=0
            Current fold number for training/validation splitting.
        val_split : float, default=0.25
            Proportion of data to be used for validation.
        train : bool, default=True
            If True, the training transforms are used; otherwise the validation transforms.
        load_data : bool, default=False
            If True, loads the entire dataset into computer memory; the transforms then do not
            include a data reader.
        fg_rate : float, default=0.33
            Foreground rate for cropping.
        num_threads_in_mt : int, default=12
            Number of threads in multi-threaded augmentation.
        **kwargs
            Accepted and ignored, so that extra configuration entries do not raise.

        Raises
        ------
        ValueError:
            If fg_path is None
        """
        if fg_path is None:
            raise ValueError("Batchgen module need foregrounds, ensure the preprocessing does it and that the path is included in the config file.")

        gen = BatchGenDataLoader(
            img_path=img_path,
            msk_path=msk_path,
            batch_size=batch_size,
            nbof_steps=nbof_steps,
            fg_path=fg_path,
            folds_csv=folds_csv,
            fold=fold,
            val_split=val_split,
            train=train,
            load_data=load_data,
            # batch_generator parameters
            num_threads_in_mt=num_threads_in_mt,
        )

        self.length = nbof_steps

        # Separate handler instance given to the transforms (only used by them when data is read lazily).
        handler = DataHandlerFactory.get(
            img_path,
            read_only=True,
            msk_path = msk_path,
            fg_path = fg_path,
        )

        # When the data is already in memory, the pipeline must not try to read it again.
        use_data_reader = not load_data

        if not train:
            transform = get_validation_transforms(
                patch_size=patch_size,
                handler=handler,
                fg_rate = fg_rate,
                use_data_reader=use_data_reader,
            )
        else:
            (rotation_for_DA,
             do_dummy_2d_data_aug,
             aug_patch_size,
             mirror_axes) = configure_rotation_dummy_da_mirroring_and_inital_patch_size(patch_size)
            transform = get_training_transforms(
                aug_patch_size=aug_patch_size,
                patch_size=patch_size,
                fg_rate = fg_rate,
                rotation_for_DA=rotation_for_DA,
                deep_supervision_scales=None,
                mirror_axes=mirror_axes,
                handler=handler,
                do_dummy_2d_data_aug=do_dummy_2d_data_aug,
                use_data_reader=use_data_reader,
            )

        # NOTE(review): the fourth positional argument of MultiThreadedAugmenter is presumably
        # `num_cached_per_queue` — confirm that passing `batch_size` there is intended.
        super().__init__(
            gen,
            transform,
            num_threads_in_mt,
            batch_size
        )

    def __len__(self)->int:
        """Return the number of batches in the batch generator."""
        return self.length
#---------------------------------------------------------------------------