Source code for biom3d.datasets.semseg_patch_fast

"""Dataset primitives for 3D segmentation datasets. Solution: patch-based approach, with the option of loading the whole dataset into memory."""


from typing import Any, Iterable, Optional
import numpy as np 
import torchio as tio
import random 
from torch.utils.data import Dataset
import pandas as pd 
from biom3d.utils import centered_pad, get_folds_train_test_df, DataHandlerFactory, DataHandler

#---------------------------------------------------------------------------
# utilities for random crops

def centered_crop(img:np.ndarray,
                  msk:np.ndarray,
                  center:Iterable[int],
                  crop_shape:Iterable[int],
                  margin:Iterable[float]=np.zeros(3),
                  )->tuple[np.ndarray,np.ndarray]:
    """
    Crop image and mask so that ``center`` sits in the middle of the crop.

    The first axis of ``img`` and ``msk`` is kept whole; the crop is applied
    to the remaining axes. If the window starts before index 0 it is clipped
    there, so the returned crop can be smaller than ``crop_shape``.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data.
    center : iterable of int
        Center voxel location for cropping.
    crop_shape : iterable of int
        Shape of the crop.
    margin : iterable of float, default np.zeros(3)
        Offset added to the crop start (truncated to int with the start).

    Raises
    ------
    AssertionError
        If center is out of the image.

    Returns
    -------
    crop_img : numpy.ndarray
        Cropped image data.
    crop_msk : numpy.ndarray
        Cropped mask data.
    """
    spatial_shape = np.array(img.shape)[1:]
    center = np.array(center)
    assert np.all(center>=0) and np.all(center<spatial_shape), "[Error] Center must be located inside the image. Center: {}, Image shape: {}".format(center, spatial_shape)

    crop_shape = np.array(crop_shape)
    margin = np.array(margin)

    # window bounds: the end is derived from the *unclipped* start, so a
    # window hanging over the lower border shrinks instead of shifting
    first = (center - crop_shape//2 + margin).astype(int)
    last = first + crop_shape
    first = np.maximum(0, first)

    # keep the leading axis whole, slice every following axis
    window = (slice(0, img.shape[0]),) + tuple(
        slice(lo, hi) for lo, hi in zip(first, last)
    )
    return img[window], msk[window]
def located_crop(img:np.ndarray,
                 msk:np.ndarray,
                 location:Iterable[int],
                 crop_shape:Iterable[int],
                 margin:Iterable[float]=np.zeros(3),
                 )->tuple[np.ndarray,np.ndarray]:
    """
    Do a random crop, forcing the location voxel to be located in the crop.

    The first axis of ``img`` and ``msk`` is kept whole; the crop is applied
    to the remaining axes. The crop start is drawn uniformly among the starts
    for which ``location`` lies inside the crop (at least ``margin`` voxels
    away from the crop faces) and the crop stays inside the image. When no
    such start exists (e.g. image smaller than the crop) the lowest
    candidate start is used and the crop may be truncated by the image border.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data.
    location : iterable of int
        Specific voxel location to include in the crop.
    crop_shape : iterable of int
        Shape of the crop.
    margin : iterable of float
        Margin around the location.

    Returns
    -------
    crop_img : numpy.ndarray
        Cropped image data.
    crop_msk : numpy.ndarray
        Cropped mask data.
    """
    img_shape = np.array(img.shape)[1:]
    location = np.array(location)
    crop_shape = np.array(crop_shape)
    margin = np.array(margin)

    # Inclusive range of admissible starts. A crop covers
    # [start, start+crop_shape), so the smallest start that still contains
    # `location` is location-crop_shape+1 (not location-crop_shape).
    first_start = np.ceil(np.maximum(0, location-crop_shape+1+margin)).astype(int)
    last_start = np.floor(np.minimum(location-margin, img_shape-crop_shape)).astype(int)

    # np.random.randint excludes `high`, hence the +1; the maximum keeps the
    # range non-empty when the constraints cannot all be met.
    high = np.maximum(first_start+1, last_start+1)
    start = np.random.randint(low=first_start, high=high)
    end = start+crop_shape

    idx = [slice(0,img.shape[0])]+[slice(s[0], s[1]) for s in zip(start, end)]
    idx = tuple(idx)

    crop_img = img[idx]
    crop_msk = msk[idx]
    return crop_img, crop_msk
def foreground_crop(img:np.ndarray,
                    msk:np.ndarray,
                    final_size:Iterable[int],
                    fg_margin:Iterable[float],
                    fg:Optional[dict[int,np.ndarray]]=None,
                    use_softmax:bool=True,
                    )->tuple[np.ndarray,np.ndarray]:
    """
    Crop image and mask around a randomly chosen foreground voxel.

    Candidate voxels come from ``fg`` when it is a non-empty dict (one key is
    picked at random). Otherwise they are searched in ``msk``: a mask whose
    first axis has length 1 is treated as a label map, any other mask as a
    one-hot encoding. When no candidate voxel exists, a plain random crop is
    returned instead.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data.
    final_size : iterable of int
        Final size of the cropped image and mask.
    fg_margin : iterable of float
        Margin around the foreground location.
    fg : dict of int to numpy.ndarray, optional
        Foreground information: voxel locations per key.
    use_softmax : bool, default=True
        If True, channel 0 of a one-hot mask is never picked as the
        foreground channel.

    Returns
    -------
    img : numpy.ndarray
        Cropped image data, focused on the foreground region.
    msk : numpy.ndarray
        Cropped mask data, corresponding to the cropped image region.
    """
    fg_keys = list(fg.keys()) if fg is not None else []

    if len(fg_keys) > 0:
        # pre-computed foreground: pick one key, use its voxel list
        voxels = fg[random.choice(fg_keys)]
    elif tuple(msk.shape)[0] == 1:
        # single-channel mask: labels are stored as voxel values
        top_label = msk.max()
        label = random.randint(1, top_label if top_label > 0 else 1)
        voxels = np.argwhere(msk[0] == label)
    else:
        # one-hot mask: one channel per label
        label = random.randint(int(use_softmax), tuple(msk.shape)[0] - 1)
        voxels = np.argwhere(msk[label] == 1)

    if np.array(voxels).size == 0:
        # nothing to center on: fall back to an unconstrained random crop
        return random_crop(img, msk, final_size, force_in=False)

    # choose a random voxel of this label and center the crop on it
    center = random.choice(voxels)
    return centered_crop(img, msk, center, final_size, fg_margin)
[docs] def centered_pad(img:np.ndarray, final_size:Iterable[int], msk:Optional[np.ndarray]=None, )->np.ndarray|tuple[np.ndarray,np.ndarray]: """ Do a centered pad an img and msk to fit the final_size. Parameters ---------- img : numpy.ndarray Image data. final_size : iterable of int Final size of the cropped image and mask. msk : numpy.ndarray, optional Mask data. Returns ------- pad_img : numpy.ndarray Cropped image data, focused on the foreground region. pad_msk : numpy.ndarray, optional Cropped mask data, corresponding to the cropped image region. """ final_size = np.array(final_size) img_shape = np.array(img.shape[1:]) start = (final_size-np.array(img_shape))//2 start = start * (start > 0) end = final_size-(img_shape+start) end = end * (end > 0) pad = np.append([[0,0]], np.stack((start,end),axis=1), axis=0) pad_img = np.pad(img, pad, 'constant', constant_values=0) if msk is not None: pad_msk = np.pad(msk, pad, 'constant', constant_values=0) return pad_img, pad_msk else: return pad_img
def random_crop(img:np.ndarray,
                msk:np.ndarray,
                crop_shape:Iterable[int],
                force_in:bool=True,
                )->tuple[np.ndarray,np.ndarray]:
    """
    Randomly crop a window of shape ``crop_shape`` from image and mask.

    The first axis of ``img`` and ``msk`` is kept whole; the crop is applied
    to the remaining axes.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data.
    crop_shape : iterable of int
        Shape of the crop.
    force_in : bool, optional
        If True, the crop start is drawn so that the crop lies inside the
        image whenever the image is large enough (no padding is applied, so
        a smaller image yields a smaller crop). If False, a random voxel is
        used as crop center and the result is zero-padded to ``crop_shape``
        if needed.

    Raises
    ------
    AssertionError
        If the image (minus its first axis) and crop_shape do not have the
        same number of dimensions.

    Returns
    -------
    crop_img : numpy.ndarray
        Cropped image data.
    crop_msk : numpy.ndarray
        Cropped mask data.
    """
    img_shape = np.array(img.shape)[1:]
    assert len(img_shape)==len(crop_shape),"[Error] Not the same dimensions! Image shape {}, Crop shape {}".format(img_shape, crop_shape)
    crop_shape = np.array(crop_shape)

    if force_in: # force the crop to be located in image shape range
        # Valid starts are 0..img_shape-crop_shape inclusive; randint's upper
        # bound is exclusive, hence the +1 (without it the last position
        # could never be drawn).
        start = np.random.randint(0, np.maximum(1,img_shape-crop_shape+1))
        end = start+crop_shape

        idx = [slice(0,img.shape[0])]+[slice(s[0], s[1]) for s in zip(start, end)]
        idx = tuple(idx)

        crop_img = img[idx]
        crop_msk = msk[idx]
    else:
        # the crop will be chosen randomly and then padded if needed
        # the crop might be too small but will be padded with zeros
        start = np.random.randint(0, img_shape)
        crop_img, crop_msk = centered_crop(img=img, msk=msk, center=start, crop_shape=crop_shape)

        # pad if needed
        if np.any(np.array(crop_img.shape)[1:] != crop_shape):
            crop_img, crop_msk = centered_pad(img=crop_img, msk=crop_msk, final_size=crop_shape)
    return crop_img, crop_msk
def random_crop_pad(img:np.ndarray,
                    msk:np.ndarray,
                    final_size:Iterable[int],
                    fg_rate:float=0.33,
                    fg_margin:Iterable[float]=np.zeros(3),
                    fg:Optional[dict[int,np.ndarray]]=None,
                    use_softmax:bool=True,
                    )->tuple[np.ndarray,np.ndarray]:
    """
    Do a random crop and pad if needed.

    With probability ``fg_rate`` the crop is centered on a foreground voxel,
    otherwise it is centered on a random voxel. The result is zero-padded to
    ``final_size`` when the crop is smaller.

    If ``img`` is a list, the function runs in batch mode: every
    (image, mask) pair is processed independently with the same options and
    the results are stacked into arrays.

    Parameters
    ----------
    img : numpy.ndarray or list of numpy.ndarray
        Image data.
    msk : numpy.ndarray or list of numpy.ndarray
        Mask data.
    final_size : iterable of int
        Final size of the image and mask after cropping and padding.
    fg_rate : float, default=0.33
        Probability of focusing the crop on the foreground.
    fg_margin : iterable of float, default=np.zeros(3)
        Margin around the foreground location.
    fg : dict of int to numpy.ndarray, optional
        Foreground information. In batch mode, pass a list/tuple with one
        entry per image; a single dict is ignored in batch mode because it
        cannot be attributed to one image.
    use_softmax : bool, default=True
        If True, assumes softmax activation.

    Returns
    -------
    img : numpy.ndarray
        Cropped and padded image data.
    msk : numpy.ndarray
        Cropped and padded mask data.
    """
    if isinstance(img,list): # then batch mode
        fg_is_per_image = isinstance(fg, (list, tuple))
        imgs, msks = [], []
        for i, (img_i, msk_i) in enumerate(zip(img, msk)):
            # forward every option: the batch must be sampled with the same
            # settings as a single image would be
            img_, msk_ = random_crop_pad(
                img_i,
                msk_i,
                final_size,
                fg_rate=fg_rate,
                fg_margin=fg_margin,
                fg=fg[i] if fg_is_per_image else None,
                use_softmax=use_softmax,
            )
            imgs.append(img_)
            msks.append(msk_)
        return np.array(imgs), np.array(msks) # can convert to array as they should have now all the same shape

    # choose if using foreground centered or random alignment
    force_fg = random.random()
    if fg_rate>0 and force_fg<fg_rate:
        img, msk = foreground_crop(img, msk, final_size, fg_margin, fg=fg, use_softmax=use_softmax)
    else:
        # or random crop
        img, msk = random_crop(img, msk, final_size, force_in=False)

    # pad if needed
    if np.any(np.array(img.shape)[1:] != np.array(final_size)):
        img, msk = centered_pad(img=img, msk=msk, final_size=final_size)
    return img, msk
def random_crop_resize(img:np.ndarray,
                       msk:np.ndarray,
                       crop_scale:float,
                       final_size:Iterable[int],
                       fg_rate:float=0.33,
                       fg_margin:Iterable[float]=np.zeros(3),
                       )->tuple[np.ndarray,np.ndarray]:
    """
    Do a random crop and resize if needed.

    A crop shape is drawn between ``final_size`` and
    ``round(final_size * crop_scale)`` (inclusive). With probability
    ``fg_rate`` the crop is forced to contain a foreground voxel of a randomly
    chosen mask channel. The crop is then resized with ``tio.Resize`` when any
    axis is larger than ``final_size``, or zero-padded when it is only
    smaller.

    Parameters
    ----------
    img : numpy.ndarray
        Image data.
    msk : numpy.ndarray
        Mask data; the foreground search reads it channel by channel
        (``msk[label] == 1``).
    crop_scale : float, >=1
        Scale factor for the crop size.
    final_size : iterable of int
        Final size of the image and mask after cropping and resizing.
    fg_rate : float, default=0.33
        Probability of focusing the crop on the foreground.
    fg_margin : iterable of float, default=np.zeros(3)
        Margin around the foreground location.

    Raises
    ------
    ValueError
        If crop_scale < 1.

    Returns
    -------
    img : numpy.ndarray or tensor
        Cropped and resized image data. When a resize took place this is the
        ``.tensor`` of the torchio image, otherwise a numpy array.
    msk : numpy.ndarray or tensor
        Cropped and resized mask data, same container type as ``img``.
    """
    final_size = np.array(final_size)

    if crop_scale < 1:
        raise ValueError(f"Crop scale must be a float >=1, found '{crop_scale}'")

    # determine crop shape
    max_crop_shape = np.round(final_size * crop_scale).astype(int)
    crop_shape = np.random.randint(final_size, max_crop_shape+1)

    # choose if using foreground centered or random alignment
    force_fg = random.random()
    if fg_rate>0 and force_fg<fg_rate:
        rnd_label = random.randint(0,msk.shape[0]-1) # choose a random label

        locations = np.argwhere(msk[rnd_label] == 1)

        if locations.size==0: # bug fix when having empty arrays
            img, msk = random_crop(img, msk, crop_shape)
        else:
            center=random.choice(locations) # choose a random voxel of this label
            img, msk = located_crop(img, msk, center, crop_shape, fg_margin)
    else:
        # or random crop
        img, msk = random_crop(img, msk, crop_shape)

    # Compare element-wise *before* reducing with np.any: comparing the
    # boolean result of np.any to 0 made the padding branch unreachable.
    size_diff = np.array(img.shape)[1:] - final_size
    if np.any(size_diff > 0):
        # at least one axis is too large: resize the whole crop
        sub = tio.Subject(img=tio.ScalarImage(tensor=img), msk=tio.LabelMap(tensor=msk))
        sub = tio.Resize(final_size)(sub)
        img, msk = sub.img.tensor, sub.msk.tensor
    elif np.any(size_diff < 0):
        # only too small: pad. Keyword arguments are required here because
        # centered_pad's second positional parameter is final_size, not msk.
        img, msk = centered_pad(img=img, msk=msk, final_size=final_size)

    return img, msk
#---------------------------------------------------------------------------
class LabelToLong:
    """
    Transform to convert label data to long (integer) type.

    :ivar str label_name: Name of the label to be transformed.
    """

    label_name:str

    def __init__(self, label_name:str):
        """
        Store the key of the entry whose data must be cast to long.

        Parameters
        ----------
        label_name : str
            Name of the label to be transformed.
        """
        self.label_name = label_name

    def __call__(self, subject:dict[str,Any])->dict[str,Any]:
        """
        Cast ``subject[label_name]`` to long, in place.

        Parameters
        ----------
        subject : dict of string to any
            Mapping from names to entries. If it contains ``self.label_name``,
            that entry must expose ``.data`` (with a ``.long()`` method) and
            ``.set_data(...)``.

        Returns
        -------
        subject : dict
            The same object that was passed in, with the label data converted
            to long (integer) type when the label is present.
        """
        name = self.label_name
        if name not in subject:
            # nothing to convert: hand the subject back untouched
            return subject

        label = subject[name]
        label.set_data(label.data.long())
        return subject
#---------------------------------------------------------------------------
class SemSeg3DPatchFast(Dataset):
    """
    Dataset class for semantic segmentation with 3D patches.

    Supports data augmentation and efficient loading.

    :ivar str img_path: Path to collection containing the image files.
    :ivar str msk_path: Path to collection containing the mask files.
    :ivar str | None fg_path: Path to collection containing the foreground files.
    :ivar int batch_size: Size of a batch.
    :ivar numpy.ndarray patch_size: Size of a patch.
    :ivar numpy.ndarray | None aug_patch_size: Size of augmented patch size, may be bigger than patch size.
    :ivar int nbof_steps: Number of steps (batches) per epoch.
    :ivar bool load_data: If True, load the entire dataset into memory.
    :ivar DataHandler handler: DataHandler used to load data.
    :ivar bool train: If True, use the dataset for training; otherwise, use it for validation.
    :ivar list[str] fnames: List of image paths relative to img_path.
    :ivar bool use_aug: Whether to use data augmentation
    :ivar float fg_rate: Foreground rate, used to force foreground inclusion in patches.
    :ivar float crop_scale: Scale factor for crop size during augmentation.
    :ivar bool use_softmax: If True, use softmax activation.
    :ivar int batch_idx: Current batch index.
    """

    img_path:str
    msk_path:str
    fg_path:Optional[str]
    batch_size:int
    patch_size:np.ndarray
    aug_patch_size:Optional[np.ndarray]
    nbof_steps:int
    load_data:bool
    # quoted so that building the class does not require the type at
    # class-creation time
    handler:"DataHandler"
    train:bool
    fnames:list[str]
    use_aug:bool
    fg_rate:float
    crop_scale:float
    use_softmax:bool
    batch_idx:int

    def __init__(
        self,
        img_path:str,
        msk_path:str,
        batch_size:int,
        patch_size:np.ndarray,
        nbof_steps:int,
        folds_csv:Optional[str] = None,
        fold:int = 0,
        val_split:float = 0.25,
        train:bool = True,
        use_aug:bool = True,
        aug_patch_size:Optional[np.ndarray] = None,
        fg_path:Optional[str] = None,
        fg_rate:float = 0.33,
        crop_scale:float = 1.0,
        load_data:bool = False,
        use_softmax:bool = True,
        ):
        """
        Dataset class for semantic segmentation with 3D patches. Supports data augmentation and efficient loading.

        Parameters
        ----------
        img_path : str
            Path to collection containing the image files.
        msk_path : str
            Path to collection containing the mask files.
        batch_size : int
            Batch size for dataset sampling.
        patch_size : numpy.ndarray
            Size of the patches to be used.
        nbof_steps : int
            Number of steps (batches) per epoch.
        folds_csv : str, optional
            CSV file containing fold information for dataset splitting.
        fold : int, default=0
            The current fold number for training/validation splitting.
        val_split : float, default=0.25
            Proportion of data to be used for validation (only used when
            folds_csv is None; at least one image is kept for validation).
        train : bool, default=True
            If True, use the dataset for training; otherwise, use it for validation.
        use_aug : bool, default=True
            If True, apply data augmentation.
        aug_patch_size : numpy.ndarray, optional
            Patch size to use for augmented patches. When None and use_aug is
            True, patch_size is used.
        fg_path : str, optional
            Path to collection containing foreground information.
        fg_rate : float, default=0.33
            Foreground rate, used to force foreground inclusion in patches.
            If > 0, force the use of foreground, needs to run some pre-computations (note: better use the foreground scheduler)
        crop_scale : float, default=1.0
            Scale factor for crop size during augmentation. If > 1, then use random_crop_resize instead of random_crop_pad
        load_data : bool, default=False
            If True, load the entire dataset into memory.
        use_softmax : bool, default=True
            If True, use softmax activation.

        Raises
        ------
        AssertionError
            If folds_csv is None and the image collection is empty or its path is invalid.
            If crop_scale < 1.
        """
        self.img_path = img_path
        self.msk_path = msk_path
        self.fg_path = fg_path

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.aug_patch_size = aug_patch_size

        self.nbof_steps = nbof_steps

        self.load_data = load_data

        self.handler = DataHandlerFactory.get(
            self.img_path,
            read_only=True,
            msk_path = msk_path,
            fg_path = fg_path,
        )

        # get the training and validation names
        if folds_csv is not None:
            df = pd.read_csv(folds_csv)
            trainset, testset = get_folds_train_test_df(df, verbose=False)

            self.fold = fold

            # the selected fold is the validation set, all others are merged
            # into the training set
            self.val_imgs = trainset[self.fold]
            del trainset[self.fold]
            self.train_imgs = []
            for i in trainset: self.train_imgs += i

        else:
            all_set = self.handler.extract_inner_path(self.handler.images)
            assert len(all_set) > 0, "[Error] Incorrect path for folder of images or your folder is empty."
            np.random.shuffle(all_set) # shuffle all_set
            val_split = np.round(val_split * len(all_set)).astype(int)
            if val_split == 0: val_split=1
            self.train_imgs = all_set[val_split:]
            self.val_imgs = all_set[:val_split]
            testset = []

        self.train = train
        if self.train:
            print("current fold: {}\n \
            length of the training set: {}\n \
            length of the validation set: {}\n \
            length of the testing set: {}".format(fold, len(self.train_imgs), len(self.val_imgs), len(testset)))

        self.fnames = self.train_imgs if self.train else self.val_imgs
        self.handler.open(
            img_path = img_path,
            msk_path = msk_path,
            fg_path = fg_path,
            img_inner_paths_list = self.fnames,
            msk_inner_paths_list = self.fnames,
            # foreground files share the image name up to its first dot;
            # split() also copes with names that contain no dot at all
            fg_inner_paths_list = [f.split('.')[0]+'.pkl' for f in self.fnames],
        )

        # print train and validation image names
        print("{} images: {}".format("Training" if self.train else "Validation", self.fnames))

        if self.load_data:
            print("Loading the whole dataset into computer memory...")
            self.imgs_data, self.msks_data, self.fg_data = [], [], []
            for i,m,f in self.handler:
                # load img and msks
                self.imgs_data.append(self.handler.load(i)[0])
                self.msks_data.append(self.handler.load(m)[0])

                # load foreground
                if fg_path is not None:
                    self.fg_data.append(self.handler.load(f)[0])
            print("Done!")

        self.use_aug = use_aug

        if self.use_aug:
            # Without an explicit augmentation patch size, fall back to the
            # patch size itself (the centering crop then removes nothing)
            # instead of failing on None.
            if self.aug_patch_size is None:
                self.aug_patch_size = self.patch_size

            ps = np.array(self.patch_size)

            # [aug] 'axes' for tio.RandomAnisotropy
            anisotropy_axes=tuple(np.arange(len(ps))[ps/ps.min()>3].tolist())
            if len(anisotropy_axes)==0:
                anisotropy_axes=tuple(np.arange(len(ps)).tolist())

            # [aug] 'degrees' for tio.RandomAffine
            if np.any(ps/ps.min()>3): # then use dummy_2d
                degrees = tuple(180 if p==ps.argmin() else 0 for p in range(len(ps)))
            else:
                degrees = (-45,45)

            # [aug] 'cropping'
            # the affine transform is computed on bigger patches than the other transform
            # that's why we need to crop the patch after potential affine augmentation
            start = (np.array(self.aug_patch_size)-np.array(self.patch_size))//2
            end = np.array(self.aug_patch_size)-(np.array(self.patch_size)+start)
            cropping = (start[0],end[0],start[1],end[1],start[2],end[2])

            # the foreground-crop-function forces the foreground to be in the center of the patch
            # so that, when doing the second centering crop, the foreground is still present in the patch,
            # that's why there is a margin here
            self.fg_margin = np.zeros(len(patch_size))

            self.rotate = tio.Compose([
                tio.RandomAffine(scales=0, degrees=degrees, translation=0, default_pad_value=0),
                tio.Crop(cropping=cropping),
                LabelToLong(label_name='msk')])

            self.transform = tio.Compose([
                tio.Compose([tio.RandomAffine(scales=(0.7,1.4), degrees=0, translation=0),
                             LabelToLong(label_name='msk')
                            ], p=0.2),

                # spatial augmentations
                tio.RandomAnisotropy(p=0.1, axes=anisotropy_axes, downsampling=(1,1.5)),
                tio.RandomFlip(p=1, axes=(0,1,2) if self.patch_size[0]!=1 else (1,2)),
                tio.RandomBiasField(p=0.15, coefficients=0.2),
                tio.RandomBlur(p=0.2, std=(0.5,1)),
                tio.RandomNoise(p=0.2, std=(0,0.1)),
                tio.RandomSwap(p=0.2, patch_size=ps//8),
                tio.RandomGamma(p=0.3, log_gamma=(-0.35,0.4)),
            ])

        self.fg_rate = fg_rate
        self.crop_scale = crop_scale
        assert self.crop_scale >= 1, "[Error] crop_scale must be higher or equal to 1"

        self.use_softmax = use_softmax
        self.batch_idx = 0

    def set_fg_rate(self,value:float)->None:
        """Setter function for the foreground rate class parameter."""
        self.fg_rate = value

    def _do_fg(self)->bool:
        """
        Determine whether to force the foreground depending on the batch idx.

        Returns
        -------
        bool
            True if batch_idx >= round(batch_size * (1 - fg_rate)).
        """
        return self.batch_idx >= round(self.batch_size * (1 - self.fg_rate))

    def _update_batch_idx(self)->None:
        """Increment batch index, modulo batch_size."""
        self.batch_idx += 1
        if self.batch_idx >= self.batch_size:
            self.batch_idx = 0

    def __len__(self)->int:
        """Return nbof_steps*batch_size."""
        return self.nbof_steps*self.batch_size

    def __getitem__(self, idx:int)->tuple[np.ndarray,np.ndarray]:
        """
        Return image and mask associated with index, with padding/cropping, and data augmentation if use_aug.

        Parameters
        ----------
        idx : int
            The index of the wanted data; it is wrapped modulo the number of
            available images.

        Returns
        -------
        img : numpy.ndarray or tensor
            Image patch. The ``.tensor`` of a torchio image when augmentation
            is applied.
        msk : numpy.ndarray or tensor
            Mask patch, converted to float.
        """
        if self.load_data:
            img = self.imgs_data[idx%len(self.imgs_data)]
            msk = self.msks_data[idx%len(self.msks_data)]
            if len(self.fg_data)>0: fg = self.fg_data[idx%len(self.fg_data)]
            else: fg = None
        else:
            idx=idx%len(self.fnames)

            # read the images
            img = self.handler.load(self.handler.images[idx])[0]
            msk = self.handler.load(self.handler.masks[idx])[0]

            # read foreground data
            if self.fg_path is not None: fg = self.handler.load(self.handler.fg[idx])[0]
            else: fg = None

        # random crop and pad
        # rotation augmentation requires a larger patch size
        do_rot = random.random() < 0.2 # rotation proba = 0.2
        final_size = self.aug_patch_size if self.use_aug and do_rot else self.patch_size
        # fg_margin only exists when augmentation was set up in __init__
        fg_margin = self.fg_margin if self.use_aug else np.zeros(len(self.patch_size))

        if self.train and self.crop_scale > 1:
            img, msk = random_crop_resize(
                img,
                msk,
                crop_scale=self.crop_scale,
                final_size=final_size,
                fg_rate=self.fg_rate,
                fg_margin=fg_margin,
                )
        else:
            img, msk = random_crop_pad(
                img,
                msk,
                final_size=final_size,
                fg_rate=int(self._do_fg()),
                fg_margin=fg_margin,
                fg = fg,
                use_softmax=self.use_softmax
                )
            # the batch index is only consumed by _do_fg() in this branch
            self._update_batch_idx()

        # data augmentation
        if self.use_aug:
            sub = tio.Subject(img=tio.ScalarImage(tensor=img), msk=tio.LabelMap(tensor=msk))
            if do_rot: sub = self.rotate(sub)
            sub = self.transform(sub)
            img, msk = sub.img.tensor, sub.msk.tensor
            # to float for msk
            msk = msk.float()
        else:
            # convert mask to float for validation; random_crop_resize may
            # hand back tensors, which have no .astype
            msk = msk.astype(float) if isinstance(msk, np.ndarray) else msk.float()

        return img, msk
#---------------------------------------------------------------------------