Source code for biom3d.utils.filtering

"""This submodule contains the thresholding and filtering functions used in post-processing."""

from skimage import measure
import numpy as np

def compute_otsu_criteria(im: np.ndarray, th: float) -> float:
    """
    Compute the Otsu criteria value for a given threshold on the image.

    This function implements the core step of Otsu's method, which evaluates
    the within-class variance weighted by class probabilities for a specific
    threshold. The goal is to find the threshold minimizing this weighted
    variance. Found here: https://en.wikipedia.org/wiki/Otsu%27s_method.

    Parameters
    ----------
    im : numpy.ndarray
        Grayscale input image (any number of dimensions).
    th : float
        Threshold value to evaluate. Pixels with ``im >= th`` form class 1,
        all other pixels (including NaNs) form class 0.

    Returns
    -------
    float
        Weighted sum of variances for the two classes separated by the
        threshold. Returns ``np.inf`` if the image is empty or if one class
        is empty, so that such a threshold is never selected.
    """
    nb_pixels = im.size
    # Guard against a division by zero below; an empty image has no valid threshold.
    if nb_pixels == 0:
        return np.inf

    # Boolean mask of class 1; class 0 is its complement.
    foreground = im >= th
    weight1 = np.count_nonzero(foreground) / nb_pixels
    weight0 = 1 - weight1

    # If one of the classes is empty (all pixels below or above the threshold),
    # that threshold is excluded from the search for the best threshold.
    if weight1 == 0 or weight0 == 0:
        return np.inf

    # Both classes are guaranteed non-empty here, so np.var is well defined.
    var1 = np.var(im[foreground])
    var0 = np.var(im[~foreground])
    return weight0 * var0 + weight1 * var1
def otsu_thresholding(im: np.ndarray) -> float:
    """
    Compute the optimal threshold for an image using Otsu's method.

    This function searches for the threshold value that minimizes the weighted
    within-class variance of the thresholded image. 255 candidate thresholds
    are evaluated, evenly spaced between ``im.min()`` and ``im.max() + 1``.

    Parameters
    ----------
    im : numpy.ndarray
        Grayscale input image (any number of dimensions), non-empty.

    Returns
    -------
    float
        Optimal threshold value computed using Otsu's method. If no candidate
        splits the image into two non-empty classes (e.g. a constant image),
        every criterion is ``np.inf`` and ``im.min()`` is returned.
    """
    # Cast to float *before* adding 1: with NumPy >= 2 (NEP 50), an unsigned
    # integer maximum such as uint8 255 would otherwise wrap around to 0 and
    # produce a degenerate threshold range.
    low = float(im.min())
    high = float(im.max()) + 1
    threshold_range = np.linspace(low, high, num=255)
    criterias = [compute_otsu_criteria(im, th) for th in threshold_range]
    return threshold_range[np.argmin(criterias)]
def dist_vec(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Euclidean distance between two vectors.

    Parameters
    ----------
    v1 : numpy.ndarray
        First vector.
    v2 : numpy.ndarray
        Second vector, broadcast-compatible with ``v1``.

    Returns
    -------
    float
        Square root of the sum of squared component differences.
    """
    diff = v2 - v1
    squared_norm = np.square(diff).sum()
    return np.sqrt(squared_norm)
def center(labels: np.ndarray, idx: int) -> np.ndarray:
    """
    Compute the barycenter of pixels belonging to a specific label.

    Parameters
    ----------
    labels : numpy.ndarray
        Label image array where each pixel has an integer label.
    idx : int
        Label index for which to compute the barycenter.

    Returns
    -------
    numpy.ndarray
        Coordinates of the barycenter as a 1D float array of length
        ``labels.ndim`` (e.g. [y, x] or [z, y, x]). If no pixel carries the
        given label, an array of NaN of that same length is returned.
    """
    coords = np.argwhere(labels == idx)
    # np.mean on an empty selection also yields NaNs but emits a
    # "Mean of empty slice" RuntimeWarning; return the NaN vector explicitly.
    if coords.shape[0] == 0:
        return np.full(labels.ndim, np.nan)
    return coords.mean(axis=0)
def closest(labels: np.ndarray, num: int) -> int:
    """
    Find the label index of the object closest to the center of the image.

    The barycenter of each label 1..num is computed, and the label whose
    barycenter has the smallest Euclidean distance to ``shape / 2`` is
    returned.

    Parameters
    ----------
    labels : numpy.ndarray
        Label image array where each pixel has an integer label.
    num : int
        Number of labels (excluding background) to consider.

    Returns
    -------
    int
        The label index (1-based) of the object closest to the image center.
        Returns 1 if there are no objects to consider (``num`` is 0).
    """
    # No foreground objects: fall back to label 1.
    if num <= 0:
        return 1

    image_center = np.array(labels.shape) / 2
    barycenters = np.stack(
        [np.mean(np.argwhere(labels == lab), axis=0) for lab in range(1, num + 1)]
    )
    offsets = barycenters - image_center
    distances = np.sqrt(np.sum(offsets * offsets, axis=1))
    return np.argmin(distances) + 1
def keep_center_only(msk: np.ndarray) -> np.ndarray:
    """
    Keep only the connected component in the mask that is closest to the image center.

    Parameters
    ----------
    msk : numpy.ndarray
        Binary mask (2D or 3D) where connected components are to be analyzed.

    Returns
    -------
    numpy.ndarray
        Mask with only the connected component closest to the center, with
        values 0 or 255. The selection is cast to ``msk.dtype`` before being
        multiplied by 255; for a bool input NumPy promotes the result to an
        integer dtype.
    """
    labels, num = measure.label(msk, background=0, return_num=True)
    target = closest(labels, num)
    selected = labels == target
    return selected.astype(msk.dtype) * 255
def volumes(labels: np.ndarray) -> np.ndarray:
    """
    Compute the volume (pixel or voxel count) of each label in the label image.

    Parameters
    ----------
    labels : numpy.ndarray
        Label image array where each pixel has an integer label.

    Returns
    -------
    numpy.ndarray
        Pixel counts, one per label present in ``labels``, ordered by
        ascending label value (the background label 0, if present, is first).
    """
    _, counts = np.unique(labels, return_counts=True)
    return counts
def keep_big_volumes(msk: np.ndarray, thres_rate: float = 0.3) -> np.ndarray:
    """
    Return a mask keeping only the largest connected components based on a volume threshold.

    The threshold is computed as:

        min_volume = thres_rate * otsu_thresholding(volumes)

    where `volumes` are the sizes of all connected components (excluding
    background), and `otsu_thresholding` finds an adaptive threshold on the
    volumes distribution. Components with a volume strictly greater than
    ``min_volume`` are kept.

    Parameters
    ----------
    msk : numpy.ndarray
        Input binary mask.
    thres_rate : float, default=0.3
        Multiplier for the threshold on volumes.

    Returns
    -------
    numpy.ndarray
        If ``msk`` has zero or one connected component, ``msk`` itself is
        returned unchanged. Otherwise a boolean mask that is True on the kept
        components and False elsewhere (all False if no component passes the
        threshold, which can happen for ``thres_rate > 1``).
    """
    # Transform the image into a label image.
    labels, num = measure.label(msk, background=0, return_num=True)

    # Empty or single volume: nothing to filter.
    if num <= 1:
        return msk

    # Volume of every label. With at least two components some background
    # pixels must exist, so label 0 is present and comes first after np.unique.
    unq_labels, vol = np.unique(labels, return_counts=True)
    unq_labels = unq_labels[1:]
    vol = vol[1:]

    min_vol = thres_rate * otsu_thresholding(vol)

    # Keep only the labels whose volume is big enough. np.isin builds the
    # union mask in one pass (and handles an empty selection) instead of
    # OR-ing one full-size comparison per kept label.
    kept_labels = unq_labels[vol > min_vol]
    return np.isin(labels, kept_labels)
def keep_biggest_volume_centered(msk: np.ndarray) -> np.ndarray:
    """
    Return a mask with only the connected component closest to the image center,
    provided its volume is not too small compared to the largest connected component.
    Otherwise, return the largest connected component.

    "Too small" means its volume is less than half of the largest component.

    The returned mask intensities are either 0 or `msk.max()`.

    Parameters
    ----------
    msk : numpy.ndarray
        Input binary mask.

    Returns
    -------
    numpy.ndarray
        Mask with only one connected component kept, or ``msk`` itself when
        it has at most one connected component.
    """
    labels, num = measure.label(msk, background=0, return_num=True)

    # A single (or no) component leaves nothing to choose between.
    if num <= 1:
        return msk

    chosen = closest(labels, num)
    sizes = volumes(labels)

    # sizes[0] is the background count; ratios[k] compares the centred
    # component to foreground label k + 1.
    ratios = sizes[chosen] / sizes[1:]

    # Defensive guard for an empty prediction (not expected when num > 1).
    if ratios.size == 0:
        return msk

    # The smallest ratio is against the largest component; if the centred
    # component is less than half of it, switch to the largest one.
    if np.min(ratios) < 0.5:
        chosen = np.argmin(ratios) + 1

    return (labels == chosen).astype(msk.dtype) * msk.max()